refactoring w/o labelled collection

Alejandro Moreo Fernandez 2025-04-20 22:05:46 +02:00
parent c79b76516c
commit 075be93a23
22 changed files with 416 additions and 248 deletions

View File

@@ -53,8 +53,8 @@ training, test = dataset.train_test
model = qp.method.aggregative.ACC()
model.fit(training)
- estim_prevalence = model.quantify(test.X)
+ estim_prevalence = model.predict(test.X)
true_prevalence = test.prevalence()
error = qp.error.mae(true_prevalence, estim_prevalence)
print(f'Mean Absolute Error (MAE)={error:.3f}')

View File

@@ -1,3 +1,27 @@
+ To remove LabelledCollection from the methods:
+ - The mess comes from the confusing semantics of fit in aggregative methods, which takes 3 parameters:
+   - data: a LabelledCollection, which can be:
+     - the training set, if the classifier must be trained
+     - None, if the classifier must not be trained
+     - the validation set (which conflicts with val_split), if the classifier must not be trained
+   - fit_classifier: states whether the classifier must be trained or not, and this changes the semantics of the other parameters
+   - val_split: which can be:
+     - a number: the number of folds for kFCV, which implies fit_classifier=True and data=the whole training set
+     - a fraction in [0,1]: the portion used for validation; implies fit_classifier=True and data=train+val
+     - a LabelledCollection: the specific validation set; implies neither fit_classifier=True nor fit_classifier=False
+ - The way to remove the methods' dependency on LabelledCollection should be as follows (see the sketch after this file):
+   - The constructor states whether the classifier received as a parameter must be trained or is already trained;
+     that is, there is a fit_classifier=True or False.
+   - fit_classifier=True:
+     - data in fit is the whole training set, validation included
+     - val_split:
+       - int: number of folds for kFCV
+       - proportion in [0,1]
+   - fit_classifier=False:
- [TODO] document confidence in manuals
- [TODO] Test the return_type="index" in protocols and finish the "distributing_samples.py" example
- [TODO] Add EDy (an implementation is available at quantificationlib)
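
The notes above describe the intended post-refactoring API. A minimal sketch of how it is meant to be used, assuming the constructor signature introduced in this commit (fit_classifier and val_split move into the constructor, fit/predict take plain arrays); the synthetic data is only for illustration:

```python
import quapy as qp
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1000, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# fit_classifier now lives in the constructor; fit() takes plain (X, y) arrays
acc = qp.method.aggregative.ACC(LogisticRegression(), fit_classifier=True, val_split=5)
acc.fit(X_train, y_train)          # trains the classifier and the aggregation function
prevalence = acc.predict(X_test)   # predict() replaces the former quantify()
```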

View File

@@ -32,8 +32,8 @@ dataset = qp.datasets.fetch_twitter('semeval16')
model = qp.method.aggregative.ACC(LogisticRegression())
model.fit(dataset.training)
- estim_prevalence = model.quantify(dataset.test.instances)
+ estim_prevalence = model.predict(dataset.test.instances)
true_prevalence = dataset.test.prevalence()
error = qp.error.mae(true_prevalence, estim_prevalence)

View File

@@ -132,7 +132,7 @@ svm = LinearSVC()
# (an alias is available in qp.method.aggregative.ClassifyAndCount)
model = qp.method.aggregative.CC(svm)
model.fit(training)
- estim_prevalence = model.quantify(test.instances)
+ estim_prevalence = model.predict(test.instances)
```
The same code could be used to instantiate an ACC, by simply replacing
@@ -172,7 +172,7 @@ The following code illustrates the case in which PCC is used:
```python
model = qp.method.aggregative.PCC(svm)
model.fit(training)
- estim_prevalence = model.quantify(test.instances)
+ estim_prevalence = model.predict(test.instances)
print('classifier:', model.classifier)
```
In this case, QuaPy will print:
@@ -263,7 +263,7 @@ dataset = qp.datasets.fetch_twitter('hcr', pickle=True)
model = qp.method.aggregative.EMQ(LogisticRegression())
model.fit(dataset.training)
- estim_prevalence = model.quantify(dataset.test.instances)
+ estim_prevalence = model.predict(dataset.test.instances)
```
_New in v0.1.7_: EMQ now accepts two new parameters in the construction method, namely
@@ -299,6 +299,7 @@ HDy was proposed as a binary classifier and the implementation
provided in QuaPy accepts only binary datasets.
The following code shows an example of use:
```python
import quapy as qp
from sklearn.linear_model import LogisticRegression
@@ -309,7 +310,7 @@ qp.data.preprocessing.text2tfidf(dataset, min_df=5, inplace=True)
model = qp.method.aggregative.HDy(LogisticRegression())
model.fit(dataset.training)
- estim_prevalence = model.quantify(dataset.test.instances)
+ estim_prevalence = model.predict(dataset.test.instances)
```
_New in v0.1.7:_ QuaPy now provides an implementation of the generalized
@@ -411,7 +412,7 @@ qp.environ['SVMPERF_HOME'] = '../svm_perf_quantification'
model = newOneVsAll(SVMQ(), n_jobs=-1)  # run them in parallel
model.fit(dataset.training)
- estim_prevalence = model.quantify(dataset.test.instances)
+ estim_prevalence = model.predict(dataset.test.instances)
```
Check the examples on [explicit_loss_minimization](https://github.com/HLT-ISTI/QuaPy/blob/devel/examples/5.explicit_loss_minimization.py)
@@ -531,7 +532,7 @@ dataset = qp.datasets.fetch_UCIBinaryDataset('haberman')
model = Ensemble(quantifier=ACC(LogisticRegression()), size=30, policy='ave', n_jobs=-1)
model.fit(dataset.training)
- estim_prevalence = model.quantify(dataset.test.instances)
+ estim_prevalence = model.predict(dataset.test.instances)
```
Other aggregation policies implemented in QuaPy include:
@@ -579,6 +580,6 @@ learner = NeuralClassifierTrainer(cnn, device='cuda')
# train QuaNet
model = QuaNet(learner, device='cuda')
model.fit(dataset.training)
- estim_prevalence = model.quantify(dataset.test.instances)
+ estim_prevalence = model.predict(dataset.test.instances)
```

View File

@@ -41,7 +41,7 @@ pacc.fit(train)
# let's now test our quantifier on the test data (of course, we should not use the test labels y at this point, only X)
X_test = test.X
- estim_prevalence = pacc.quantify(X_test)
+ estim_prevalence = pacc.predict(X_test)
print(f'estimated test prevalence = {F.strprev(estim_prevalence)}')
print(f'true test prevalence = {F.strprev(test.prevalence())}')

View File

@@ -123,7 +123,7 @@ def _get_estimate(estimator_class, training: LabelledCollection, test: np.ndarray
    """Auxiliary method for running ACC and PACC."""
    estimator = estimator_class(get_random_forest())
    estimator.fit(training)
-   return estimator.quantify(test)
+   return estimator.predict(test)

def train_and_plot_bayesian_quantification(ax: plt.Axes, training: LabelledCollection, test: LabelledCollection) -> None:
@@ -133,7 +133,7 @@ def train_and_plot_bayesian_quantification(ax: plt.Axes, training: LabelledCollection, test: LabelledCollection) -> None:
    quantifier.fit(training)

    # Obtain mean prediction
-   mean_prediction = quantifier.quantify(test.X)
+   mean_prediction = quantifier.predict(test.X)
    mae = qp.error.mae(test.prevalence(), mean_prediction)
    x_ax = np.arange(training.n_classes)
    ax.plot(x_ax, mean_prediction, c="salmon", linewidth=2, linestyle=":", label="Bayesian")

View File

@@ -39,10 +39,10 @@ class MyQuantifier(BaseQuantifier):
        return self

    # in general, we would need to implement the method quantify(self, instances); this would amount to:
-   def quantify(self, instances):
+   def predict(self, X):
        assert hasattr(self.classifier, 'predict_proba'), \
            'the underlying classifier is not probabilistic! [abort]'
-       posterior_probabilities = self.classifier.predict_proba(instances)
+       posterior_probabilities = self.classifier.predict_proba(X)
        positive_probabilities = posterior_probabilities[:, 1]
        crisp_decisions = positive_probabilities > self.alpha
        pos_prev = crisp_decisions.mean()
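
For context, a self-contained version of this toy quantifier under the renamed API might look as follows. This is a sketch, assuming the plain (X, y) fit signature adopted in this commit; the binary-prevalence packaging at the end is an assumption about the elided tail of the excerpt:

```python
import numpy as np
from quapy.method.base import BaseQuantifier

class MyQuantifier(BaseQuantifier):
    """Toy quantifier: counts the posteriors that exceed a threshold alpha."""

    def __init__(self, classifier, alpha=0.5):
        self.classifier = classifier
        self.alpha = alpha

    def fit(self, X, y):
        self.classifier.fit(X, y)
        return self

    def predict(self, X):
        assert hasattr(self.classifier, 'predict_proba'), \
            'the underlying classifier is not probabilistic! [abort]'
        positive_probabilities = self.classifier.predict_proba(X)[:, 1]
        pos_prev = (positive_probabilities > self.alpha).mean()
        return np.asarray([1 - pos_prev, pos_prev])
```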

View File

@@ -27,7 +27,7 @@ quantifier = QuaNet(cnn_classifier, device='cuda')
quantifier.fit(train, fit_classifier=False)

# prediction and evaluation
- estim_prevalence = quantifier.quantify(test.instances)
+ estim_prevalence = quantifier.predict(test.instances)
mae = qp.error.mae(test.prevalence(), estim_prevalence)
print(f'true prevalence: {F.strprev(test.prevalence())}')

View File

@@ -14,7 +14,7 @@ from . import model_selection
from . import classification
import os

- __version__ = '0.1.10'
+ __version__ = '0.1.10r'

environ = {
    'SAMPLE_SIZE': None,
@@ -24,7 +24,7 @@ environ = {
    'PAD_INDEX': 1,
    'SVMPERF_HOME': './svm_perf_quantification',
    'N_JOBS': int(os.getenv('N_JOBS', 1)),
-   'DEFAULT_CLS': LogisticRegression(max_iter=3000)
+   'DEFAULT_CLS': LogisticRegression()
}
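
Since the default classifier loses max_iter=3000 here, convergence warnings may reappear on some datasets. environ is a plain dict, so the old default can be restored explicitly (a sketch):

```python
import quapy as qp
from sklearn.linear_model import LogisticRegression

# restore the previous default classifier if LogisticRegression() does not converge
qp.environ['DEFAULT_CLS'] = LogisticRegression(max_iter=3000)
```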

View File

@@ -232,11 +232,11 @@ class LabelledCollection:
        :return: two instances of :class:`LabelledCollection`, the first one with `train_prop` elements, and the
            second one with `1-train_prop` elements
        """
-       tr_docs, te_docs, tr_labels, te_labels = train_test_split(
+       tr_X, te_X, tr_y, te_y = train_test_split(
            self.instances, self.labels, train_size=train_prop, stratify=self.labels, random_state=random_state
        )
-       training = LabelledCollection(tr_docs, tr_labels, classes=self.classes_)
-       test = LabelledCollection(te_docs, te_labels, classes=self.classes_)
+       training = LabelledCollection(tr_X, tr_y, classes=self.classes_)
+       test = LabelledCollection(te_X, te_y, classes=self.classes_)
        return training, test

    def split_random(self, train_prop=0.6, random_state=None):
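
A brief usage sketch of the method whose internals are renamed above (behavior is unchanged; the data here is synthetic):

```python
from sklearn.datasets import make_classification
from quapy.data import LabelledCollection

X, y = make_classification(n_samples=500, random_state=0)
data = LabelledCollection(X, y)
training, test = data.split_stratified(train_prop=0.7, random_state=0)
print(training.prevalence(), test.prevalence())  # stratification keeps both close
```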

View File

@@ -63,7 +63,7 @@ def prediction(
        protocol_with_predictions = protocol.on_preclassified_instances(pre_classified)
        return __prediction_helper(model.aggregate, protocol_with_predictions, verbose)
    else:
-       return __prediction_helper(model.quantify, protocol, verbose)
+       return __prediction_helper(model.predict, protocol, verbose)

def __prediction_helper(quantification_fn, protocol: AbstractProtocol, verbose=False):

View File

@@ -38,7 +38,7 @@ class QuaNetTrainer(BaseQuantifier):
    >>> # train QuaNet (QuaNet is an alias to QuaNetTrainer)
    >>> model = QuaNet(classifier, qp.environ['SAMPLE_SIZE'], device='cuda')
    >>> model.fit(dataset.training)
-   >>> estim_prevalence = model.quantify(dataset.test.instances)
+   >>> estim_prevalence = model.predict(dataset.test.instances)

    :param classifier: an object implementing `fit` (i.e., that can be trained on labelled data),
        `predict_proba` (i.e., that can generate posterior probabilities of unlabelled examples) and
@@ -201,9 +201,9 @@ class QuaNetTrainer(BaseQuantifier):
        return prevs_estim

-   def quantify(self, instances):
-       posteriors = self.classifier.predict_proba(instances)
-       embeddings = self.classifier.transform(instances)
+   def predict(self, X):
+       posteriors = self.classifier.predict_proba(X)
+       embeddings = self.classifier.transform(X)
        quant_estims = self._get_aggregative_estims(posteriors)
        self.quanet.eval()
        with torch.no_grad():

View File

@@ -5,8 +5,10 @@ import numpy as np
from abstention.calibration import NoBiasVectorScaling, TempScaling, VectorScaling
from sklearn.base import BaseEstimator
from sklearn.calibration import CalibratedClassifierCV
+ from sklearn.exceptions import NotFittedError
from sklearn.metrics import confusion_matrix
- from sklearn.model_selection import cross_val_predict
+ from sklearn.model_selection import cross_val_predict, train_test_split
+ from sklearn.utils.validation import check_is_fitted

import quapy as qp
import quapy.functional as F
@@ -14,6 +16,7 @@ from quapy.functional import get_divergence
from quapy.classification.svmperf import SVMperf
from quapy.data import LabelledCollection
from quapy.method.base import BaseQuantifier, BinaryQuantifier, OneVsAllGeneric
+ from quapy.method import _bayesian

# Abstract classes
@@ -35,18 +38,53 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
    and :meth:`aggregate`.
    """

-   val_split_ = None
-
-   @property
-   def val_split(self):
-       return self.val_split_
-
-   @val_split.setter
-   def val_split(self, val_split):
-       if isinstance(val_split, LabelledCollection):
-           print('warning: setting val_split with a LabelledCollection will be inefficient in'
-                 'model selection. Rather pass the LabelledCollection at fit time')
-       self.val_split_ = val_split
+   def __init__(self, classifier: Union[None, BaseEstimator], fit_classifier: bool = True, val_split: Union[int, float, tuple, None] = 5):
+       self.classifier = qp._get_classifier(classifier)
+       self.fit_classifier = fit_classifier
+       self.val_split = val_split
+
+       # basic type checks
+       assert hasattr(self.classifier, 'fit'), \
+           f'the classifier does not implement "fit"'
+       assert isinstance(fit_classifier, bool), \
+           f'unexpected type for {fit_classifier=}; must be True or False'
+
+       if isinstance(val_split, int):
+           assert val_split > 1, \
+               (f'when {val_split=} is indicated as an integer, it represents the number of folds in a kFCV '
+                f'and must thus be >1')
+           assert fit_classifier, (f'when {val_split=} is indicated as an integer (the number of folds for kFCV) '
+                                   f'the parameter {fit_classifier=} must be True')
+       elif isinstance(val_split, float):
+           assert 0 < val_split < 1, \
+               (f'when {val_split=} is indicated as a float, it represents the fraction of training instances '
+                f'to be used for validation, and must thus be in the range (0,1)')
+           assert fit_classifier, (f'when {val_split=} is indicated as a float (the fraction of training instances '
+                                   f'to be used for validation), the parameter {fit_classifier=} must be True')
+       elif isinstance(val_split, tuple):
+           assert len(val_split) == 2, \
+               (f'when {val_split=} is indicated as a tuple, it represents the collection (X, y) on which the '
+                f'validation must be performed, but this seems to have a different cardinality')
+       elif val_split is None:
+           pass
+       else:
+           raise ValueError(f'unexpected type for {val_split=}')
+
+       # is the classifier already fitted?
+       try:
+           check_is_fitted(self.classifier)
+           fitted = True
+       except NotFittedError:
+           fitted = False
+
+       # consistency checks: fit_classifier?
+       if self.fit_classifier:
+           if fitted:
+               raise RuntimeWarning(f'the classifier is already fitted, but {fit_classifier=} was requested')
+       else:
+           assert fitted, (f'{fit_classifier=} requires the classifier to be already trained, '
+                           f'but it does not seem to be')
    def _check_init_parameters(self):
        """
@@ -58,20 +96,36 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
        """
        pass

-   def _check_non_empty_classes(self, data: LabelledCollection):
+   def _check_non_empty_classes(self, y):
        """
        Asserts all classes have positive instances.

-       :param data: LabelledCollection
+       :param y: array-like of shape `(n_instances,)` with the label for each instance; prevalence values are
+           computed against `self.classes_`, so the prevalence vector comes out right even when some classes
+           have no examples
        :return: Nothing. May raise an exception.
        """
-       sample_prevs = data.prevalence()
-       empty_classes = np.argwhere(sample_prevs==0).flatten()
-       if len(empty_classes)>0:
-           empty_class_names = data.classes_[empty_classes]
+       sample_prevs = F.prevalence_from_labels(y, self.classes_)
+       empty_classes = np.argwhere(sample_prevs == 0).flatten()
+       if len(empty_classes) > 0:
+           empty_class_names = self.classes_[empty_classes]
            raise ValueError(f'classes {empty_class_names} have no training examples')

-   def fit(self, data: LabelledCollection, fit_classifier=True, val_split=None):
+   def fit(self, X, y):
+       """
+       Trains the aggregative quantifier. This comes down to training a classifier (if requested) and an
+       aggregation function.
+
+       :param X: array-like, the training instances
+       :param y: array-like, the labels
+       :return: self
+       """
+       self._check_init_parameters()
+       classif_predictions = self.classifier_fit_predict(X, y)
+       self.aggregation_fit(classif_predictions)
+       return self
def fit_depr(self, data: LabelledCollection, fit_classifier=True, val_split=None):
""" """
Trains the aggregative quantifier. This comes down to training a classifier and an aggregation function. Trains the aggregative quantifier. This comes down to training a classifier and an aggregation function.
@ -88,94 +142,55 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
:return: self :return: self
""" """
self._check_init_parameters() self._check_init_parameters()
classif_predictions = self.classifier_fit_predict(data, fit_classifier, predict_on=val_split) classif_predictions = self.classifier_fit_predict_depr(data, fit_classifier, predict_on=val_split)
self.aggregation_fit(classif_predictions, data) self.aggregation_fit_depr(classif_predictions, data)
return self return self
-   def classifier_fit_predict(self, data: LabelledCollection, fit_classifier=True, predict_on=None):
+   def classifier_fit_predict(self, X, y):
        """
        Trains the classifier if requested (`fit_classifier=True`) and generates the necessary predictions to
        train the aggregation function.

-       :param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
-       :param fit_classifier: whether to train the learner (default is True). Set to False if the
-           learner has been trained outside the quantifier.
-       :param predict_on: specifies the set on which predictions need to be issued. This parameter can
-           be specified as None (default) to indicate no prediction is needed; a float in (0, 1) to
-           indicate the proportion of instances to be used for predictions (the remainder is used for
-           training); an integer >1 to indicate that the predictions must be generated via k-fold
-           cross-validation, using this integer as k; or the data sample itself on which to generate
-           the predictions.
+       :param X: array-like, the training instances
+       :param y: array-like, the labels
        """
-       assert isinstance(fit_classifier, bool), 'unexpected type for "fit_classifier", must be boolean'
-       self._check_classifier(adapt_if_necessary=(self._classifier_method() == 'predict_proba'))
+       self._check_classifier()
+       # self._check_non_empty_classes(y)

-       if fit_classifier:
-           self._check_non_empty_classes(data)
-
-       if predict_on is None:
-           if not fit_classifier:
-               predict_on = data
-           if isinstance(self.val_split, LabelledCollection) and self.val_split!=predict_on:
-               raise ValueError(f'{fit_classifier=} but a LabelledCollection was provided as val_split '
-                                f'in __init__ that is not the same as the LabelledCollection provided in fit.')
-           if predict_on is None:
-               predict_on = self.val_split
-
-       if predict_on is None:
-           if fit_classifier:
-               self.classifier.fit(*data.Xy)
-           predictions = None
-       elif isinstance(predict_on, float):
-           if fit_classifier:
-               if not (0. < predict_on < 1.):
-                   raise ValueError(f'proportion {predict_on=} out of range, must be in (0,1)')
-               train, val = data.split_stratified(train_prop=(1 - predict_on))
-               self.classifier.fit(*train.Xy)
-               predictions = LabelledCollection(self.classify(val.X), val.y, classes=data.classes_)
-           else:
-               raise ValueError(f'wrong type for predict_on: since fit_classifier=False, '
-                                f'the set on which predictions have to be issued must be '
-                                f'explicitly indicated')
-       elif isinstance(predict_on, LabelledCollection):
-           if fit_classifier:
-               self.classifier.fit(*data.Xy)
-           predictions = LabelledCollection(self.classify(predict_on.X), predict_on.y, classes=predict_on.classes_)
-       elif isinstance(predict_on, int):
-           if fit_classifier:
-               if predict_on <= 1:
-                   raise ValueError(f'invalid value {predict_on} in fit. '
-                                    f'Specify a integer >1 for kFCV estimation.')
-               else:
-                   n_jobs = self.n_jobs if hasattr(self, 'n_jobs') else qp._get_njobs(None)
-                   predictions = cross_val_predict(
-                       self.classifier, *data.Xy, cv=predict_on, n_jobs=n_jobs, method=self._classifier_method())
-                   predictions = LabelledCollection(predictions, data.y, classes=data.classes_)
-                   self.classifier.fit(*data.Xy)
-           else:
-               raise ValueError(f'wrong type for predict_on: since fit_classifier=False, '
-                                f'the set on which predictions have to be issued must be '
-                                f'explicitly indicated')
+       if isinstance(self.val_split, int):
+           assert self.fit_classifier, f'unexpected value for {self.fit_classifier=}'
+           num_folds = self.val_split
+           n_jobs = self.n_jobs if hasattr(self, 'n_jobs') else qp._get_njobs(None)
+           predictions = cross_val_predict(self.classifier, X, y, cv=num_folds, n_jobs=n_jobs, method=self._classifier_method())
+           yval = y
+           self.classifier.fit(X, y)
+       elif isinstance(self.val_split, float):
+           assert self.fit_classifier, f'unexpected value for {self.fit_classifier=}'
+           train_prop = 1. - self.val_split
+           Xtr, Xval, ytr, yval = train_test_split(X, y, train_size=train_prop, stratify=y)
+           self.classifier.fit(Xtr, ytr)
+           predictions = self.classify(Xval)
+       elif isinstance(self.val_split, tuple):
+           Xval, yval = self.val_split
+           if self.fit_classifier:
+               self.classifier.fit(X, y)
+           predictions = self.classify(Xval)  # without this line, `predictions` would be unbound in this branch
+       elif self.val_split is None:
+           if self.fit_classifier:
+               self.classifier.fit(X, y)
+           predictions, yval = None, None
        else:
-           raise ValueError(
-               f'error: param "predict_on" ({type(predict_on)}) not understood; '
-               f'use either a float indicating the split proportion, or a '
-               f'tuple (X,y) indicating the validation partition')
+           raise ValueError(f'unexpected type for {self.val_split=}')

-       return predictions
+       return predictions, yval
    @abstractmethod
-   def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
+   def aggregation_fit(self, classif_predictions, **kwargs):
        """
        Trains the aggregation function.

-       :param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
-           as instances, the predictions issued by the classifier and, as labels, the true labels
-       :param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
+       :param classif_predictions: the classification predictions; whatever the method
+           :meth:`classify` returns
        """
        ...
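
To make the new contract concrete: classifier_fit_predict returns a (predictions, yval) pair, and aggregation_fit consumes it. Below is a minimal, illustrative subclass written against the signatures in this commit; it is a sketch, not library code, and the thresholding rule is invented for the example:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from quapy.method.aggregative import AggregativeSoftQuantifier

class ThresholdCounter(AggregativeSoftQuantifier):
    """Toy quantifier: counts posteriors above a threshold tuned on the validation split."""

    def aggregation_fit(self, classif_predictions):
        # the (predictions, yval) pair returned by classifier_fit_predict
        posteriors, yval = classif_predictions
        # choose the threshold that reproduces the validation prevalence
        self.threshold = np.quantile(posteriors[:, 1], 1 - yval.mean())

    def aggregate(self, classif_predictions):
        pos = float((classif_predictions[:, 1] >= self.threshold).mean())
        return np.asarray([1 - pos, pos])

X, y = make_classification(n_samples=1000, random_state=0)
quantifier = ThresholdCounter(LogisticRegression(), fit_classifier=True, val_split=0.3)
quantifier.fit(X, y)
print(quantifier.predict(X))
```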
@@ -197,16 +212,16 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
        """
        self.classifier_ = classifier

-   def classify(self, instances):
+   def classify(self, X):
        """
        Provides the label predictions for the given instances. The predictions should respect the format expected by
        :meth:`aggregate`, e.g., posterior probabilities for probabilistic quantifiers, or crisp predictions for
        non-probabilistic quantifiers. The default one is "decision_function".

-       :param instances: array-like of shape `(n_instances, n_features,)`
+       :param X: array-like of shape `(n_instances, n_features,)`
        :return: np.ndarray of shape `(n_instances,)` with label predictions
        """
-       return getattr(self.classifier, self._classifier_method())(instances)
+       return getattr(self.classifier, self._classifier_method())(X)

    def _classifier_method(self):
        """
@@ -221,26 +236,26 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
        Guarantees that the underlying classifier implements the method required for issuing predictions, i.e.,
        the method indicated by the :meth:`_classifier_method`

-       :param adapt_if_necessary: if True, the method will try to comply with the required specifications
+       :param adapt_if_necessary: unused unless overridden
        """
        assert hasattr(self.classifier, self._classifier_method()), \
            f"the method does not implement the required {self._classifier_method()} method"

-   def quantify(self, instances):
+   def predict(self, X):
        """
        Generate class prevalence estimates for the sample's instances by aggregating the label predictions generated
        by the classifier.

-       :param instances: array-like
+       :param X: array-like
        :return: `np.ndarray` of shape `(n_classes)` with class prevalence estimates.
        """
-       classif_predictions = self.classify(instances)
+       classif_predictions = self.classify(X)
        return self.aggregate(classif_predictions)

    @abstractmethod
    def aggregate(self, classif_predictions: np.ndarray):
        """
-       Implements the aggregation of label predictions.
+       Implements the aggregation of the classifier predictions.

        :param classif_predictions: `np.ndarray` of label predictions
        :return: `np.ndarray` of shape `(n_classes,)` with class prevalence estimates.
@@ -324,9 +339,9 @@ class BinaryAggregativeQuantifier(AggregativeQuantifier, BinaryQuantifier):
    def neg_label(self):
        return self.classifier.classes_[0]

-   def fit(self, data: LabelledCollection, fit_classifier=True, val_split=None):
-       self._check_binary(data, self.__class__.__name__)
-       return super().fit(data, fit_classifier, val_split)
+   def fit(self, X, y):
+       self._check_binary(y, self.__class__.__name__)
+       return super().fit(X, y)

# Methods
@@ -338,16 +353,14 @@ class CC(AggregativeCrispQuantifier):
    :param classifier: a sklearn's Estimator that generates a classifier
    """

-   def __init__(self, classifier: BaseEstimator=None):
-       self.classifier = qp._get_classifier(classifier)
+   def __init__(self, classifier: BaseEstimator = None, fit_classifier: bool = True):
+       super().__init__(classifier, fit_classifier, val_split=None)

-   def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
+   def aggregation_fit(self, classif_predictions):
        """
        Nothing to do here!

        :param classif_predictions: not used
-       :param data: not used
        """
        pass
@@ -369,15 +382,14 @@ class PCC(AggregativeSoftQuantifier):
    :param classifier: a sklearn's Estimator that generates a classifier
    """

-   def __init__(self, classifier: BaseEstimator=None):
-       self.classifier = qp._get_classifier(classifier)
+   def __init__(self, classifier: BaseEstimator = None, fit_classifier: bool = True):
+       super().__init__(classifier, fit_classifier, val_split=None)

-   def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
+   def aggregation_fit(self, classif_predictions):
        """
        Nothing to do here!

        :param classif_predictions: not used
-       :param data: not used
        """
        pass
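
CC and PCC pass val_split=None to the base constructor because their aggregation step has nothing to fit, so no validation predictions are ever generated. Usage under the new API then reduces to (a sketch with synthetic data):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from quapy.method.aggregative import CC

X, y = make_classification(n_samples=1000, random_state=0)
cc = CC(LogisticRegression(), fit_classifier=True)  # no validation split needed
cc.fit(X, y)
print(cc.predict(X))
```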
@@ -430,17 +442,18 @@ class ACC(AggregativeCrispQuantifier):
    :param n_jobs: number of parallel workers
    """

    def __init__(
        self,
-       classifier: BaseEstimator=None,
+       classifier: BaseEstimator = None,
+       fit_classifier=True,
        val_split=5,
        solver: Literal['minimize', 'exact', 'exact-raise', 'exact-cc'] = 'minimize',
        method: Literal['inversion', 'invariant-ratio'] = 'inversion',
        norm: Literal['clip', 'mapsimplex', 'condsoftmax'] = 'clip',
        n_jobs=None,
    ):
-       self.classifier = qp._get_classifier(classifier)
-       self.val_split = val_split
+       super().__init__(classifier, fit_classifier, val_split)
        self.n_jobs = qp._get_njobs(n_jobs)
        self.solver = solver
        self.method = method
@@ -451,24 +464,25 @@ class ACC(AggregativeCrispQuantifier):
    NORMALIZATIONS = ['clip', 'mapsimplex', 'condsoftmax', None]

    @classmethod
-   def newInvariantRatioEstimation(cls, classifier: BaseEstimator, val_split=5, n_jobs=None):
+   def newInvariantRatioEstimation(cls, classifier: BaseEstimator, fit_classifier=True, val_split=5, n_jobs=None):
        """
        Constructs a quantifier that implements the Invariant Ratio Estimator of
        `Vaz et al. 2018 <https://jmlr.org/papers/v20/18-456.html>`_. This amounts
        to setting method to 'invariant-ratio' and clipping to 'project'.

        :param classifier: a sklearn's Estimator that generates a classifier
+       :param fit_classifier: bool, whether to fit the classifier or not
        :param val_split: specifies the data used for generating classifier predictions. This specification
            can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
            be extracted from the training set; or as an integer (default 5), indicating that the predictions
            are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
            for `k`); or as a collection defining the specific set of data to use for validation.
            Alternatively, this set can be specified at fit time by indicating the exact set of data
            on which the predictions are to be generated.
        :param n_jobs: number of parallel workers
        :return: an instance of ACC configured so that it implements the Invariant Ratio Estimator
        """
-       return ACC(classifier, val_split=val_split, method='invariant-ratio', norm='mapsimplex', n_jobs=n_jobs)
+       return ACC(classifier, fit_classifier=fit_classifier, val_split=val_split, method='invariant-ratio', norm='mapsimplex', n_jobs=n_jobs)
    def _check_init_parameters(self):
        if self.solver not in ACC.SOLVERS:
@@ -478,7 +492,7 @@ class ACC(AggregativeCrispQuantifier):
        if self.norm not in ACC.NORMALIZATIONS:
            raise ValueError(f"unknown normalization; valid ones are {ACC.NORMALIZATIONS}")

-   def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
+   def aggregation_fit(self, classif_predictions):
        """
        Estimates the misclassification rates.
@@ -486,8 +500,8 @@ class ACC(AggregativeCrispQuantifier):
            as instances, the label predictions issued by the classifier and, as labels, the true labels
        :param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
        """
-       pred_labels, true_labels = classif_predictions.Xy
-       self.cc = CC(self.classifier)
+       pred_labels, true_labels = classif_predictions
+       self.cc = CC(self.classifier, fit_classifier=False)
        self.Pte_cond_estim_ = ACC.getPteCondEstim(self.classifier.classes_, true_labels, pred_labels)

    @classmethod
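
For reference, the adjustment these misclassification rates feed is a linear system: if M[i, j] = P(predicted=i | true=j) and p_hat is the vector of observed predicted-class rates, ACC (roughly, with solver='exact' and norm='clip') solves M p = p_hat. A NumPy-only illustration with made-up numbers:

```python
import numpy as np

# made-up misclassification-rate matrix: M[i, j] = P(predicted=i | true=j)
M = np.array([[0.9, 0.2],
              [0.1, 0.8]])
p_hat = np.array([0.6, 0.4])   # observed rates of the predicted classes (from CC)

p = np.linalg.solve(M, p_hat)  # invert the system
p = np.clip(p, 0, 1)
p /= p.sum()                   # 'clip' normalization back onto the simplex
print(p)                       # corrected prevalence estimate, about [0.571, 0.429]
```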
@@ -529,6 +543,8 @@ class PACC(AggregativeSoftQuantifier):
    :param classifier: a sklearn's Estimator that generates a classifier
+   :param fit_classifier: bool, whether to fit the classifier or not
    :param val_split: specifies the data used for generating classifier predictions. This specification
        can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
        be extracted from the training set; or as an integer (default 5), indicating that the predictions
@@ -565,17 +581,18 @@ class PACC(AggregativeSoftQuantifier):
    :param n_jobs: number of parallel workers
    """

    def __init__(
        self,
-       classifier: BaseEstimator=None,
+       classifier: BaseEstimator = None,
+       fit_classifier=True,
        val_split=5,
        solver: Literal['minimize', 'exact', 'exact-raise', 'exact-cc'] = 'minimize',
        method: Literal['inversion', 'invariant-ratio'] = 'inversion',
        norm: Literal['clip', 'mapsimplex', 'condsoftmax'] = 'clip',
        n_jobs=None
    ):
-       self.classifier = qp._get_classifier(classifier)
-       self.val_split = val_split
+       super().__init__(classifier, fit_classifier, val_split)
        self.n_jobs = qp._get_njobs(n_jobs)
        self.solver = solver
        self.method = method
@@ -589,7 +606,7 @@ class PACC(AggregativeSoftQuantifier):
        if self.norm not in ACC.NORMALIZATIONS:
            raise ValueError(f"unknown normalization; valid ones are {ACC.NORMALIZATIONS}")

-   def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
+   def aggregation_fit(self, classif_predictions):
        """
        Estimates the misclassification rates
@@ -597,8 +614,8 @@ class PACC(AggregativeSoftQuantifier):
            as instances, the posterior probabilities issued by the classifier and, as labels, the true labels
        :param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
        """
-       posteriors, true_labels = classif_predictions.Xy
-       self.pcc = PCC(self.classifier)
+       posteriors, true_labels = classif_predictions
+       self.pcc = PCC(self.classifier, fit_classifier=False)
        self.Pte_cond_estim_ = PACC.getPteCondEstim(self.classifier.classes_, true_labels, posteriors)

    def aggregate(self, classif_posteriors):
@@ -640,6 +657,7 @@ class EMQ(AggregativeSoftQuantifier):
    and to recalibrate the posterior probabilities of the classifier.

    :param classifier: a sklearn's Estimator that generates a classifier
+   :param fit_classifier: bool, whether to fit the classifier or not
    :param val_split: specifies the data used for generating classifier predictions. This specification
        can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
        be extracted from the training set; or as an integer, indicating that the predictions
MAX_ITER = 1000 MAX_ITER = 1000
EPSILON = 1e-4 EPSILON = 1e-4
def __init__(self, classifier: BaseEstimator=None, val_split=None, exact_train_prev=True, recalib=None, n_jobs=None): def __init__(self, classifier: BaseEstimator = None, fit_classifier=True, val_split=None, exact_train_prev=True, recalib=None,
self.classifier = qp._get_classifier(classifier) n_jobs=None):
self.val_split = val_split super().__init__(classifier, fit_classifier, val_split)
self.exact_train_prev = exact_train_prev self.exact_train_prev = exact_train_prev
self.recalib = recalib self.recalib = recalib
self.n_jobs = n_jobs self.n_jobs = n_jobs
@classmethod @classmethod
def EMQ_BCTS(cls, classifier: BaseEstimator, n_jobs=None): def EMQ_BCTS(cls, classifier: BaseEstimator, fit_classifier=True, val_split=5, n_jobs=None):
""" """
Constructs an instance of EMQ using the best configuration found in the `Alexandari et al. paper Constructs an instance of EMQ using the best configuration found in the `Alexandari et al. paper
<http://proceedings.mlr.press/v119/alexandari20a.html>`_, i.e., one that relies on Bias-Corrected Temperature <http://proceedings.mlr.press/v119/alexandari20a.html>`_, i.e., one that relies on Bias-Corrected Temperature
@ -682,46 +700,46 @@ class EMQ(AggregativeSoftQuantifier):
:param n_jobs: number of parallel workers. :param n_jobs: number of parallel workers.
:return: An instance of EMQ with BCTS :return: An instance of EMQ with BCTS
""" """
return EMQ(classifier, val_split=5, exact_train_prev=False, recalib='bcts', n_jobs=n_jobs) return EMQ(classifier, fit_classifier=fit_classifier, val_split=val_split, exact_train_prev=False, recalib='bcts', n_jobs=n_jobs)
    def _check_init_parameters(self):
        if self.val_split is not None:
            if self.exact_train_prev and self.recalib is None:
                raise RuntimeWarning(f'The parameter {self.val_split=} was specified for EMQ, while the parameters '
                                     f'{self.exact_train_prev=} and {self.recalib=}. This has no effect and causes an unnecessary '
                                     f'overload.')
        else:
            if self.recalib is not None:
                print(f'[warning] The parameter {self.recalib=} requires the val_split be different from None. '
                      f'This parameter will be set to 5. To avoid this warning, set this value to a float value '
                      f'indicating the proportion of training data to be used as validation, or to an integer '
                      f'indicating the number of folds for kFCV.')
-               self.val_split=5
+               self.val_split = 5
-   def classify(self, instances):
+   def classify(self, X):
        """
        Provides the posterior probabilities for the given instances. If the classifier was required
        to be recalibrated, then these posteriors are recalibrated accordingly.

-       :param instances: array-like of shape `(n_instances, n_dimensions,)`
+       :param X: array-like of shape `(n_instances, n_dimensions,)`
        :return: np.ndarray of shape `(n_instances, n_classes,)` with posterior probabilities
        """
-       posteriors = self.classifier.predict_proba(instances)
+       posteriors = self.classifier.predict_proba(X)
        if hasattr(self, 'calibration_function') and self.calibration_function is not None:
            posteriors = self.calibration_function(posteriors)
        return posteriors
-   def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
+   def aggregation_fit(self, classif_predictions):
        """
        Trains the aggregation function of EMQ. This comes down to recalibrating the posterior probabilities
        if requested.

        :param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
            as instances, the posterior probabilities issued by the classifier and, as labels, the true labels
-       :param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
        """
+       P, y = classif_predictions
+       n_classes = len(self.classes_)
        if self.recalib is not None:
-           P, y = classif_predictions.Xy
            if self.recalib == 'nbvs':
                calibrator = NoBiasVectorScaling()
            elif self.recalib == 'bcts':
@@ -735,11 +753,11 @@ class EMQ(AggregativeSoftQuantifier):
                    '"nbvs", "bcts", "ts", and "vs".')

            if not np.issubdtype(y.dtype, np.number):
-               y = np.searchsorted(data.classes_, y)
-           self.calibration_function = calibrator(P, np.eye(data.n_classes)[y], posterior_supplied=True)
+               y = np.searchsorted(self.classes_, y)
+           self.calibration_function = calibrator(P, np.eye(n_classes)[y], posterior_supplied=True)

        if self.exact_train_prev:
-           self.train_prevalence = data.prevalence()
+           self.train_prevalence = F.prevalence_from_labels(y, self.classes_)
        else:
-           train_posteriors = classif_predictions.X
+           train_posteriors = P
            if self.recalib is not None:
@@ -806,6 +824,101 @@ class EMQ(AggregativeSoftQuantifier):
        return qs, ps

+ class BayesianCC(AggregativeCrispQuantifier):
+     """
+     `Bayesian quantification <https://arxiv.org/abs/2302.09159>`_ method,
+     which is a variant of :class:`ACC` that calculates the posterior probability distribution
+     over the prevalence vectors, rather than providing a point estimate obtained
+     by matrix inversion.
+
+     Can be used to diagnose degeneracy in the predictions, visible when the confusion
+     matrix has a high condition number, or to quantify uncertainty around the point estimate.
+
+     This method relies on extra dependencies, which have to be installed via:
+     `$ pip install quapy[bayes]`
+
+     :param classifier: a sklearn's Estimator that generates a classifier
+     :param val_split: a float in (0, 1) indicating the proportion of the training data to be used,
+         as a stratified held-out validation set, for generating classifier predictions.
+     :param num_warmup: number of warmup iterations for the MCMC sampler (default 500)
+     :param num_samples: number of samples to draw from the posterior (default 1000)
+     :param mcmc_seed: random seed for the MCMC sampler (default 0)
+     """
+     def __init__(self,
+                  classifier: BaseEstimator = None,
+                  val_split: float = 0.75,
+                  num_warmup: int = 500,
+                  num_samples: int = 1_000,
+                  mcmc_seed: int = 0):
+
+         if num_warmup <= 0:
+             raise ValueError(f'parameter {num_warmup=} must be a positive integer')
+         if num_samples <= 0:
+             raise ValueError(f'parameter {num_samples=} must be a positive integer')
+
+         if (not isinstance(val_split, float)) or val_split <= 0 or val_split >= 1:
+             raise ValueError(f'val_split must be a float in (0, 1), got {val_split}')
+
+         if _bayesian.DEPENDENCIES_INSTALLED is False:
+             raise ImportError("Auxiliary dependencies are required. Run `$ pip install quapy[bayes]` to install them.")
+
+         self.classifier = qp._get_classifier(classifier)
+         self.val_split = val_split
+         self.num_warmup = num_warmup
+         self.num_samples = num_samples
+         self.mcmc_seed = mcmc_seed
+
+         # Array of shape (n_classes, n_predicted_classes,) where entry (y, c) is the number of instances
+         # labeled as class y and predicted as class c.
+         # By default, this array is set to None and later defined as part of the `aggregation_fit` phase
+         self._n_and_c_labeled = None
+
+         # Dictionary with posterior samples, set when `aggregate` is provided.
+         self._samples = None
+
+     def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
+         """
+         Estimates the misclassification rates.
+
+         :param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
+             as instances, the label predictions issued by the classifier and, as labels, the true labels
+         :param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
+         """
+         pred_labels, true_labels = classif_predictions.Xy
+         self._n_and_c_labeled = confusion_matrix(y_true=true_labels, y_pred=pred_labels,
+                                                  labels=self.classifier.classes_)
+
+     def sample_from_posterior(self, classif_predictions):
+         if self._n_and_c_labeled is None:
+             raise ValueError("aggregation_fit must be called before sample_from_posterior")
+
+         n_c_unlabeled = F.counts_from_labels(classif_predictions, self.classifier.classes_)
+
+         self._samples = _bayesian.sample_posterior(
+             n_c_unlabeled=n_c_unlabeled,
+             n_y_and_c_labeled=self._n_and_c_labeled,
+             num_warmup=self.num_warmup,
+             num_samples=self.num_samples,
+             seed=self.mcmc_seed,
+         )
+         return self._samples
+
+     def get_prevalence_samples(self):
+         if self._samples is None:
+             raise ValueError("sample_from_posterior must be called before get_prevalence_samples")
+         return self._samples[_bayesian.P_TEST_Y]
+
+     def get_conditional_probability_samples(self):
+         if self._samples is None:
+             raise ValueError("sample_from_posterior must be called before get_conditional_probability_samples")
+         return self._samples[_bayesian.P_C_COND_Y]
+
+     def aggregate(self, classif_predictions):
+         samples = self.sample_from_posterior(classif_predictions)[_bayesian.P_TEST_Y]
+         return np.asarray(samples.mean(axis=0), dtype=float)
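
The sampling accessors above make it easy to go beyond the point estimate that aggregate returns; for instance, a credible interval can be derived from the posterior samples. A hedged sketch, assuming `quantifier` is a fitted BayesianCC whose aggregate() has already been invoked on the test predictions, and that the samples array has shape (num_samples, n_classes):

```python
import numpy as np

samples = quantifier.get_prevalence_samples()        # posterior samples over prevalence vectors
low, high = np.quantile(samples[:, 1], [0.05, 0.95])
print(f'positive-class prevalence in [{low:.3f}, {high:.3f}] (90% credible interval)')
```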
class HDy(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
    """
    `Hellinger Distance y <https://www.sciencedirect.com/science/article/pii/S0020025512004069>`_ (HDy).
@@ -821,7 +934,7 @@ class HDy(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
        validation distribution, or a :class:`quapy.data.base.LabelledCollection` (the split itself), or an integer indicating the number of folds (default 5).
    """

-   def __init__(self, classifier: BaseEstimator=None, val_split=5):
+   def __init__(self, classifier: BaseEstimator = None, val_split=5):
        self.classifier = qp._get_classifier(classifier)
        self.val_split = val_split
@@ -897,7 +1010,8 @@ class DyS(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
    :param n_jobs: number of parallel workers.
    """

-   def __init__(self, classifier: BaseEstimator=None, val_split=5, n_bins=8, divergence: Union[str, Callable]= 'HD', tol=1e-05, n_jobs=None):
+   def __init__(self, classifier: BaseEstimator = None, val_split=5, n_bins=8, divergence: Union[str, Callable] = 'HD',
+                tol=1e-05, n_jobs=None):
        self.classifier = qp._get_classifier(classifier)
        self.val_split = val_split
        self.tol = tol
@@ -962,7 +1076,7 @@ class SMM(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
        validation distribution, or a :class:`quapy.data.base.LabelledCollection` (the split itself), or an integer indicating the number of folds (default 5).
    """

-   def __init__(self, classifier: BaseEstimator=None, val_split=5):
+   def __init__(self, classifier: BaseEstimator = None, val_split=5):
        self.classifier = qp._get_classifier(classifier)
        self.val_split = val_split
@@ -986,7 +1100,7 @@ class SMM(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
        Px = classif_posteriors[:, self.pos_label]  # takes only the P(y=+1|x)
        Px_mean = np.mean(Px)

-       class1_prev = (Px_mean - self.Pxy0_mean)/(self.Pxy1_mean - self.Pxy0_mean)
+       class1_prev = (Px_mean - self.Pxy0_mean) / (self.Pxy1_mean - self.Pxy0_mean)
        return F.as_binary_prevalence(class1_prev, clip_if_necessary=True)
@@ -1011,7 +1125,7 @@ class DMy(AggregativeSoftQuantifier):
    :param n_jobs: number of parallel workers (default None)
    """

-   def __init__(self, classifier: BaseEstimator=None, val_split=5, nbins=8, divergence: Union[str, Callable]='HD',
+   def __init__(self, classifier: BaseEstimator = None, val_split=5, nbins=8, divergence: Union[str, Callable] = 'HD',
                 cdf=False, search='optim_minimize', n_jobs=None):
        self.classifier = qp._get_classifier(classifier)
        self.val_split = val_split
@@ -1040,7 +1154,7 @@ class DMy(AggregativeSoftQuantifier):
            histograms.append(hist)
        counts = np.vstack(histograms)
-       distributions = counts/counts.sum(axis=1)[:,np.newaxis]
+       distributions = counts / counts.sum(axis=1)[:, np.newaxis]
        if self.cdf:
            distributions = np.cumsum(distributions, axis=1)
        return distributions
@@ -1064,7 +1178,7 @@ class DMy(AggregativeSoftQuantifier):
        self.validation_distribution = qp.util.parallel(
            func=self._get_distributions,
-           args=[posteriors[true_labels==cat] for cat in range(n_classes)],
+           args=[posteriors[true_labels == cat] for cat in range(n_classes)],
            n_jobs=self.n_jobs,
            backend='threading'
        )
@@ -1083,16 +1197,16 @@ class DMy(AggregativeSoftQuantifier):
        test_distribution = self._get_distributions(posteriors)
        divergence = get_divergence(self.divergence)
        n_classes, n_channels, nbins = self.validation_distribution.shape

        def loss(prev):
            prev = np.expand_dims(prev, axis=0)
-           mixture_distribution = (prev @ self.validation_distribution.reshape(n_classes,-1)).reshape(n_channels, -1)
+           mixture_distribution = (prev @ self.validation_distribution.reshape(n_classes, -1)).reshape(n_channels, -1)
            divs = [divergence(test_distribution[ch], mixture_distribution[ch]) for ch in range(n_channels)]
            return np.mean(divs)

        return F.argmin_prevalence(loss, n_classes, method=self.search)
def newELM(svmperf_base=None, loss='01', C=1): def newELM(svmperf_base=None, loss='01', C=1):
""" """
Explicit Loss Minimization (ELM) quantifiers. Explicit Loss Minimization (ELM) quantifiers.
@ -1145,6 +1259,7 @@ def newSVMQ(svmperf_base=None, C=1):
""" """
return newELM(svmperf_base, loss='q', C=C) return newELM(svmperf_base, loss='q', C=C)
def newSVMKLD(svmperf_base=None, C=1): def newSVMKLD(svmperf_base=None, C=1):
""" """
SVM(KLD) is an Explicit Loss Minimization (ELM) quantifier set to optimize for the Kullback-Leibler Divergence SVM(KLD) is an Explicit Loss Minimization (ELM) quantifier set to optimize for the Kullback-Leibler Divergence
@ -1195,6 +1310,7 @@ def newSVMKLD(svmperf_base=None, C=1):
""" """
return newELM(svmperf_base, loss='nkld', C=C) return newELM(svmperf_base, loss='nkld', C=C)
def newSVMAE(svmperf_base=None, C=1): def newSVMAE(svmperf_base=None, C=1):
""" """
SVM(AE) is an Explicit Loss Minimization (ELM) quantifier set to optimize for the Absolute Error as first used by SVM(AE) is an Explicit Loss Minimization (ELM) quantifier set to optimize for the Absolute Error as first used by
@ -1219,6 +1335,7 @@ def newSVMAE(svmperf_base=None, C=1):
""" """
return newELM(svmperf_base, loss='mae', C=C) return newELM(svmperf_base, loss='mae', C=C)
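These `newSVM*` helpers are thin wrappers that call `newELM` with a specific SVMperf loss. A hedged usage sketch; the installation path below is an assumption, and a locally compiled, quantification-patched SVMperf is required before fitting:

```python
import quapy as qp

# assumed installation path of the quantification-oriented SVMperf patch
svmq = qp.method.aggregative.newSVMQ(svmperf_base='./svm_perf_quantification', C=1)
# svmq.fit(...) / svmq.predict(...) then behave like any other aggregative quantifier
```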
def newSVMRAE(svmperf_base=None, C=1): def newSVMRAE(svmperf_base=None, C=1):
""" """
SVM(RAE) is an Explicit Loss Minimization (ELM) quantifier set to optimize for the Relative Absolute Error as first SVM(RAE) is an Explicit Loss Minimization (ELM) quantifier set to optimize for the Relative Absolute Error as first
@ -1269,7 +1386,7 @@ class OneVsAllAggregative(OneVsAllGeneric, AggregativeQuantifier):
self.n_jobs = qp._get_njobs(n_jobs) self.n_jobs = qp._get_njobs(n_jobs)
self.parallel_backend = parallel_backend self.parallel_backend = parallel_backend
def classify(self, instances): def classify(self, X):
""" """
If the base quantifier is not probabilistic, returns a matrix of shape `(n,m,)` with `n` the number of If the base quantifier is not probabilistic, returns a matrix of shape `(n,m,)` with `n` the number of
instances and `m` the number of classes. The entry `(i,j)` is a binary value indicating whether instance instances and `m` the number of classes. The entry `(i,j)` is a binary value indicating whether instance
@ -1280,11 +1397,11 @@ class OneVsAllAggregative(OneVsAllGeneric, AggregativeQuantifier):
posterior probability that instance `i` belongs (resp. does not belong) to class `j`. The posterior posterior probability that instance `i` belongs (resp. does not belong) to class `j`. The posterior
probabilities are independent of each other, meaning that, in general, they do not sum up to one. probabilities are independent of each other, meaning that, in general, they do not sum up to one.
:param instances: array-like :param X: array-like
:return: `np.ndarray` :return: `np.ndarray`
""" """
classif_predictions = self._parallel(self._delayed_binary_classification, instances) classif_predictions = self._parallel(self._delayed_binary_classification, X)
if isinstance(self.binary_quantifier, AggregativeSoftQuantifier): if isinstance(self.binary_quantifier, AggregativeSoftQuantifier):
return np.swapaxes(classif_predictions, 0, 1) return np.swapaxes(classif_predictions, 0, 1)
else: else:
@ -1314,6 +1431,7 @@ class AggregativeMedianEstimator(BinaryQuantifier):
:param param_grid: the grid of parameters over which the median will be computed :param param_grid: the grid of parameters over which the median will be computed
:param n_jobs: number of parallel workers :param n_jobs: number of parallel workers
""" """
def __init__(self, base_quantifier: AggregativeQuantifier, param_grid: dict, random_state=None, n_jobs=None): def __init__(self, base_quantifier: AggregativeQuantifier, param_grid: dict, random_state=None, n_jobs=None):
self.base_quantifier = base_quantifier self.base_quantifier = base_quantifier
self.param_grid = param_grid self.param_grid = param_grid
@ -1350,7 +1468,6 @@ class AggregativeMedianEstimator(BinaryQuantifier):
model.aggregation_fit(predictions, training) model.aggregation_fit(predictions, training)
return model return model
def fit(self, training: LabelledCollection, **kwargs): def fit(self, training: LabelledCollection, **kwargs):
import itertools import itertools
@ -1407,28 +1524,27 @@ class AggregativeMedianEstimator(BinaryQuantifier):
return np.median(prev_preds, axis=0) return np.median(prev_preds, axis=0)
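The final aggregation is just a per-class median over the prevalence vectors produced by each hyperparameter configuration; in isolation (assumed values):

```python
import numpy as np

prev_preds = np.asarray([[0.2, 0.8],
                         [0.3, 0.7],
                         [0.25, 0.75]])   # assumed predictions from three configurations
print(np.median(prev_preds, axis=0))     # [0.25 0.75]
```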
#--------------------------------------------------------------- # ---------------------------------------------------------------
# imports # imports
#--------------------------------------------------------------- # ---------------------------------------------------------------
from . import _threshold_optim from . import _threshold_optim
T50 = _threshold_optim.T50 T50 = _threshold_optim.T50
MAX = _threshold_optim.MAX MAX = _threshold_optim.MAX
X = _threshold_optim.X X = _threshold_optim.X
MS = _threshold_optim.MS MS = _threshold_optim.MS
MS2 = _threshold_optim.MS2 MS2 = _threshold_optim.MS2
from . import _kdey from . import _kdey
KDEyML = _kdey.KDEyML KDEyML = _kdey.KDEyML
KDEyHD = _kdey.KDEyHD KDEyHD = _kdey.KDEyHD
KDEyCS = _kdey.KDEyCS KDEyCS = _kdey.KDEyCS
#--------------------------------------------------------------- # ---------------------------------------------------------------
# aliases # aliases
#--------------------------------------------------------------- # ---------------------------------------------------------------
ClassifyAndCount = CC ClassifyAndCount = CC
AdjustedClassifyAndCount = ACC AdjustedClassifyAndCount = ACC

View File

@ -14,26 +14,36 @@ import numpy as np
class BaseQuantifier(BaseEstimator): class BaseQuantifier(BaseEstimator):
""" """
Abstract Quantifier. A quantifier is defined as an object of a class that implements the method :meth:`fit` on Abstract Quantifier. A quantifier is defined as an object of a class that implements the method :meth:`fit` on
:class:`quapy.data.base.LabelledCollection`, the method :meth:`quantify`, and the :meth:`set_params` and a pair X, y, the method :meth:`predict`, and the :meth:`set_params` and
:meth:`get_params` for model selection (see :meth:`quapy.model_selection.GridSearchQ`) :meth:`get_params` for model selection (see :meth:`quapy.model_selection.GridSearchQ`)
""" """
@abstractmethod @abstractmethod
def fit(self, data: LabelledCollection): def fit(self, X, y):
""" """
Trains a quantifier. Generates a quantifier.
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data :param X: array-like, the training instances
:param y: array-like, the labels
:return: self :return: self
""" """
... ...
@abstractmethod @abstractmethod
def quantify(self, instances): def predict(self, X):
""" """
Generate class prevalence estimates for the sample's instances Generate class prevalence estimates for the sample's instances
:param instances: array-like :param X: array-like, the test instances
:return: `np.ndarray` of shape `(n_classes,)` with class prevalence estimates.
"""
...
def quantify(self, X):
"""
Alias to :meth:`predict`, kept for backwards compatibility
:param X: array-like
:return: `np.ndarray` of shape `(n_classes,)` with class prevalence estimates. :return: `np.ndarray` of shape `(n_classes,)` with class prevalence estimates.
""" """
... ...
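Under the refactored interface, a custom quantifier only needs `fit(X, y)` and `predict(X)`. A minimal hypothetical subclass, assuming nothing beyond the abstract signatures declared above:

```python
import numpy as np
from quapy.method.base import BaseQuantifier

class TrivialQuantifier(BaseQuantifier):
    """Hypothetical example: predicts the training prevalence for any test sample."""

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        self.prev_ = np.asarray([np.mean(y == c) for c in self.classes_])
        return self

    def predict(self, X):
        return self.prev_  # ignores X, like MaximumLikelihoodPrevalenceEstimation
```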
@ -45,8 +55,9 @@ class BinaryQuantifier(BaseQuantifier):
(typically, to be interpreted as one class and its complement). (typically, to be interpreted as one class and its complement).
""" """
def _check_binary(self, data: LabelledCollection, quantifier_name): def _check_binary(self, y, quantifier_name):
assert data.binary, f'{quantifier_name} works only on problems of binary classification. ' \ n_classes = len(set(y))
assert n_classes==2, f'{quantifier_name} works only on problems of binary classification. ' \
f'Use the class OneVsAll to enable {quantifier_name} to work on single-label data.' f'Use the class OneVsAll to enable {quantifier_name} to work on single-label data.'
@ -66,7 +77,7 @@ def newOneVsAll(binary_quantifier: BaseQuantifier, n_jobs=None):
class OneVsAllGeneric(OneVsAll, BaseQuantifier): class OneVsAllGeneric(OneVsAll, BaseQuantifier):
""" """
Allows any binary quantifier to perform quantification on single-label datasets. The method maintains one binary Allows any binary quantifier to perform quantification on single-label datasets. The method maintains one binary
quantifier for each class, and then l1-normalizes the outputs so that the class prevelence values sum up to 1. quantifier for each class, and then l1-normalizes the outputs so that the class prevalence values sum up to 1.
""" """
def __init__(self, binary_quantifier: BaseQuantifier, n_jobs=None): def __init__(self, binary_quantifier: BaseQuantifier, n_jobs=None):
@ -93,8 +104,8 @@ class OneVsAllGeneric(OneVsAll, BaseQuantifier):
) )
) )
def quantify(self, instances): def predict(self, X):
prevalences = self._parallel(self._delayed_binary_predict, instances) prevalences = self._parallel(self._delayed_binary_predict, X)
return qp.functional.normalize_prevalence(prevalences) return qp.functional.normalize_prevalence(prevalences)
@property @property
@ -102,7 +113,7 @@ class OneVsAllGeneric(OneVsAll, BaseQuantifier):
return sorted(self.dict_binary_quantifiers.keys()) return sorted(self.dict_binary_quantifiers.keys())
def _delayed_binary_predict(self, c, X): def _delayed_binary_predict(self, c, X):
return self.dict_binary_quantifiers[c].quantify(X)[1] return self.dict_binary_quantifiers[c].predict(X)[1]
def _delayed_binary_fit(self, c, data): def _delayed_binary_fit(self, c, data):
bindata = LabelledCollection(data.instances, data.labels == c, classes=[False, True]) bindata = LabelledCollection(data.instances, data.labels == c, classes=[False, True])
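A hedged usage sketch of the one-vs-all wrapper: a binary-only method such as MS is replicated once per class, and the per-class outputs are l1-normalized (the import path is assumed to be this module):

```python
import quapy as qp
from sklearn.linear_model import LogisticRegression
from quapy.method.base import newOneVsAll

binary_q = qp.method.aggregative.MS(LogisticRegression())  # a binary-only quantifier
ova = newOneVsAll(binary_q, n_jobs=-1)
# ova.fit(train); prevs = ova.predict(test_X)  -> one binary quantifier per class
```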

View File

@ -72,12 +72,12 @@ class MedianEstimator2(BinaryQuantifier):
def _delayed_predict(self, args): def _delayed_predict(self, args):
model, instances = args model, instances = args
return model.quantify(instances) return model.predict(instances)
def quantify(self, instances): def predict(self, X):
prev_preds = qp.util.parallel( prev_preds = qp.util.parallel(
self._delayed_predict, self._delayed_predict,
((model, instances) for model in self.models), ((model, X) for model in self.models),
seed=qp.environ.get('_R_SEED', None), seed=qp.environ.get('_R_SEED', None),
n_jobs=self.n_jobs n_jobs=self.n_jobs
) )
@ -174,12 +174,12 @@ class MedianEstimator(BinaryQuantifier):
def _delayed_predict(self, args): def _delayed_predict(self, args):
model, instances = args model, instances = args
return model.quantify(instances) return model.predict(instances)
def quantify(self, instances): def predict(self, X):
prev_preds = qp.util.parallel( prev_preds = qp.util.parallel(
self._delayed_predict, self._delayed_predict,
((model, instances) for model in self.models), ((model, X) for model in self.models),
seed=qp.environ.get('_R_SEED', None), seed=qp.environ.get('_R_SEED', None),
n_jobs=self.n_jobs, n_jobs=self.n_jobs,
asarray=False asarray=False
@ -294,15 +294,15 @@ class Ensemble(BaseQuantifier):
self._sout('Fit [Done]') self._sout('Fit [Done]')
return self return self
def quantify(self, instances): def predict(self, X):
predictions = np.asarray( predictions = np.asarray(
qp.util.parallel(_delayed_quantify, ((Qi, instances) for Qi in self.ensemble), n_jobs=self.n_jobs) qp.util.parallel(_delayed_quantify, ((Qi, X) for Qi in self.ensemble), n_jobs=self.n_jobs)
) )
if self.policy == 'ptr': if self.policy == 'ptr':
predictions = self._ptr_policy(predictions) predictions = self._ptr_policy(predictions)
elif self.policy == 'ds': elif self.policy == 'ds':
predictions = self._ds_policy(predictions, instances) predictions = self._ds_policy(predictions, X)
predictions = np.mean(predictions, axis=0) predictions = np.mean(predictions, axis=0)
return F.normalize_prevalence(predictions) return F.normalize_prevalence(predictions)
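A hedged sketch of how such an ensemble is typically assembled (mirroring the constructor used in the tests further below); 'ptr' and 'ds' are the member selection/weighting policies applied at prediction time in the code above:

```python
import quapy as qp
from sklearn.linear_model import LogisticRegression

base = qp.method.aggregative.PACC(LogisticRegression())
ens = qp.method.meta.Ensemble(quantifier=base, size=10, policy='ptr', n_jobs=-1)
# ens.fit(train); prev = ens.predict(test_X)
```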
@ -470,7 +470,7 @@ def _delayed_new_instance(args):
def _delayed_quantify(args): def _delayed_quantify(args):
quantifier, instances = args quantifier, instances = args
return quantifier[0].quantify(instances) return quantifier[0].predict(instances)
def _draw_simplex(ndim, min_val, max_trials=100): def _draw_simplex(ndim, min_val, max_trials=100):

View File

@ -30,11 +30,11 @@ class MaximumLikelihoodPrevalenceEstimation(BaseQuantifier):
self.estimated_prevalence = data.prevalence() self.estimated_prevalence = data.prevalence()
return self return self
def quantify(self, instances): def predict(self, X):
""" """
Ignores the input instances and returns, as the class prevalence estimates, the training prevalence. Ignores the input instances and returns, as the class prevalence estimates, the training prevalence.
:param instances: array-like (ignored) :param X: array-like (ignored)
:return: the class prevalence seen during training :return: the class prevalence seen during training
""" """
return self.estimated_prevalence return self.estimated_prevalence
@ -122,20 +122,20 @@ class DMx(BaseQuantifier):
return self return self
def quantify(self, instances): def predict(self, X):
""" """
Searches for the mixture model parameter (the sought prevalence values) that yields a validation distribution Searches for the mixture model parameter (the sought prevalence values) that yields a validation distribution
(the mixture) that best matches the test distribution, in terms of the divergence measure of choice. (the mixture) that best matches the test distribution, in terms of the divergence measure of choice.
The matching is computed as the average dissimilarity (in terms of the dissimilarity measure of choice) The matching is computed as the average dissimilarity (in terms of the dissimilarity measure of choice)
between all feature-specific discrete distributions. between all feature-specific discrete distributions.
:param instances: instances in the sample :param X: instances in the sample
:return: a vector of class prevalence estimates :return: a vector of class prevalence estimates
""" """
assert instances.shape[1] == self.nfeats, f'wrong shape; expected {self.nfeats}, found {instances.shape[1]}' assert X.shape[1] == self.nfeats, f'wrong shape; expected {self.nfeats}, found {X.shape[1]}'
test_distribution = self.__get_distributions(instances) test_distribution = self.__get_distributions(X)
divergence = get_divergence(self.divergence) divergence = get_divergence(self.divergence)
n_classes, n_feats, nbins = self.validation_distribution.shape n_classes, n_feats, nbins = self.validation_distribution.shape
def loss(prev): def loss(prev):
@ -163,8 +163,8 @@ class ReadMe(BaseQuantifier):
X = self.vectorizer.fit_transform(X) X = self.vectorizer.fit_transform(X)
self.class_conditional_X = {i: X[y==i] for i in range(data.classes_)} self.class_conditional_X = {i: X[y==i] for i in range(data.classes_)}
def quantify(self, instances): def predict(self, X):
X = self.vectorizer.transform(instances) X = self.vectorizer.transform(X)
# number of features # number of features
num_docs, num_feats = X.shape num_docs, num_feats = X.shape

View File

@ -275,15 +275,15 @@ class GridSearchQ(BaseQuantifier):
return self return self
def quantify(self, instances): def predict(self, X):
"""Estimate class prevalence values using the best model found after calling the :meth:`fit` method. """Estimate class prevalence values using the best model found after calling the :meth:`fit` method.
:param instances: sample containing the instances :param X: sample containing the instances
:return: an ndarray of shape `(n_classes)` with class prevalence estimates according to the best model found :return: an ndarray of shape `(n_classes)` with class prevalence estimates according to the best model found
by the model selection process. by the model selection process.
""" """
assert hasattr(self, 'best_model_'), 'quantify called before fit' assert hasattr(self, 'best_model_'), 'predict called before fit'
return self.best_model().quantify(instances) return self.best_model().predict(X)
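A hedged end-to-end sketch of model selection with the renamed API, assuming the usual `GridSearchQ` signature (the grid values and the choice of the APP protocol are illustrative; the dataset fetch downloads data on first use):

```python
import quapy as qp
from sklearn.linear_model import LogisticRegression

data = qp.datasets.fetch_UCIMulticlassDataset(qp.datasets.UCI_MULTICLASS_DATASETS[1])
train, test = data.train_test
train, val = train.split_stratified(train_prop=0.6)   # hold out data for model selection

pacc = qp.method.aggregative.PACC(LogisticRegression())
gsq = qp.model_selection.GridSearchQ(
    pacc,
    param_grid={'classifier__C': [0.1, 1, 10]},
    protocol=qp.protocol.APP(val),   # artificial-prevalence samples from validation
    error=qp.error.mae,
)
gsq.fit(train)
print(gsq.predict(test.X))  # delegates to the best model found
```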
def set_params(self, **parameters): def set_params(self, **parameters):
"""Sets the hyper-parameters to explore. """Sets the hyper-parameters to explore.
@ -365,7 +365,7 @@ def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfol
for train, test in data.kFCV(nfolds=nfolds, random_state=random_state): for train, test in data.kFCV(nfolds=nfolds, random_state=random_state):
quantifier.fit(train) quantifier.fit(train)
fold_prev = quantifier.quantify(test.X) fold_prev = quantifier.predict(test.X)
rel_size = 1. * len(test) / len(data) rel_size = 1. * len(test) / len(data)
total_prev += fold_prev*rel_size total_prev += fold_prev*rel_size
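The loop above aggregates fold-level estimates into a single prevalence vector by weighting each fold by its relative size; the arithmetic in isolation (assumed values):

```python
import numpy as np

fold_prevs = np.asarray([[0.4, 0.6],
                         [0.5, 0.5]])        # assumed per-fold estimates
fold_sizes = np.asarray([60, 40])            # assumed fold sizes
weights = fold_sizes / fold_sizes.sum()
print((fold_prevs * weights[:, None]).sum(axis=0))   # [0.44 0.56]
```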

View File

@ -18,7 +18,7 @@ class TestDatasets(unittest.TestCase):
q = self.new_quantifier() q = self.new_quantifier()
print(f'testing method {q} in {dataset.name}...', end='') print(f'testing method {q} in {dataset.name}...', end='')
q.fit(dataset.training) q.fit(dataset.training)
estim_prevalences = q.quantify(dataset.test.instances) estim_prevalences = q.predict(dataset.test.instances)
self.assertTrue(F.check_prevalence_vector(estim_prevalences)) self.assertTrue(F.check_prevalence_vector(estim_prevalences))
print(f'[done]') print(f'[done]')
@ -26,7 +26,7 @@ class TestDatasets(unittest.TestCase):
for X, p in gen(): for X, p in gen():
if vectorizer is not None: if vectorizer is not None:
X = vectorizer.transform(X) X = vectorizer.transform(X)
estim_prevalences = q.quantify(X) estim_prevalences = q.predict(X)
self.assertTrue(F.check_prevalence_vector(estim_prevalences)) self.assertTrue(F.check_prevalence_vector(estim_prevalences))
max_samples_test -= 1 max_samples_test -= 1
if max_samples_test == 0: if max_samples_test == 0:

View File

@ -41,8 +41,8 @@ class EvalTestCase(unittest.TestCase):
def __init__(self, cls): def __init__(self, cls):
self.emq = EMQ(cls) self.emq = EMQ(cls)
def quantify(self, instances): def predict(self, X):
return self.emq.quantify(instances) return self.emq.predict(X)
def fit(self, data): def fit(self, data):
self.emq.fit(data) self.emq.fit(data)

View File

@ -51,7 +51,7 @@ class TestMethods(unittest.TestCase):
q = model(learner) q = model(learner)
print('testing', q) print('testing', q)
q.fit(dataset.training, fit_classifier=False) q.fit(dataset.training, fit_classifier=False)
estim_prevalences = q.quantify(dataset.test.X) estim_prevalences = q.predict(dataset.test.X)
self.assertTrue(check_prevalence_vector(estim_prevalences)) self.assertTrue(check_prevalence_vector(estim_prevalences))
def test_non_aggregative(self): def test_non_aggregative(self):
@ -65,7 +65,7 @@ class TestMethods(unittest.TestCase):
q = model() q = model()
print(f'testing {q} on dataset {dataset.name}') print(f'testing {q} on dataset {dataset.name}')
q.fit(dataset.training) q.fit(dataset.training)
estim_prevalences = q.quantify(dataset.test.X) estim_prevalences = q.predict(dataset.test.X)
self.assertTrue(check_prevalence_vector(estim_prevalences)) self.assertTrue(check_prevalence_vector(estim_prevalences))
def test_ensembles(self): def test_ensembles(self):
@ -81,7 +81,7 @@ class TestMethods(unittest.TestCase):
print(f'testing {base_quantifier} on dataset {dataset.name} with {policy=}') print(f'testing {base_quantifier} on dataset {dataset.name} with {policy=}')
ensemble = Ensemble(quantifier=base_quantifier, size=3, policy=policy, n_jobs=-1) ensemble = Ensemble(quantifier=base_quantifier, size=3, policy=policy, n_jobs=-1)
ensemble.fit(dataset.training) ensemble.fit(dataset.training)
estim_prevalences = ensemble.quantify(dataset.test.instances) estim_prevalences = ensemble.predict(dataset.test.instances)
self.assertTrue(check_prevalence_vector(estim_prevalences)) self.assertTrue(check_prevalence_vector(estim_prevalences))
def test_quanet(self): def test_quanet(self):
@ -107,7 +107,7 @@ class TestMethods(unittest.TestCase):
model = QuaNet(learner, device='cpu', n_epochs=2, tr_iter_per_poch=10, va_iter_per_poch=10, patience=2) model = QuaNet(learner, device='cpu', n_epochs=2, tr_iter_per_poch=10, va_iter_per_poch=10, patience=2)
model.fit(dataset.training) model.fit(dataset.training)
estim_prevalences = model.quantify(dataset.test.instances) estim_prevalences = model.predict(dataset.test.instances)
self.assertTrue(check_prevalence_vector(estim_prevalences)) self.assertTrue(check_prevalence_vector(estim_prevalences))
def test_composable(self): def test_composable(self):
@ -115,7 +115,7 @@ class TestMethods(unittest.TestCase):
for q in COMPOSABLE_METHODS: for q in COMPOSABLE_METHODS:
print('testing', q) print('testing', q)
q.fit(dataset.training) q.fit(dataset.training)
estim_prevalences = q.quantify(dataset.test.X) estim_prevalences = q.predict(dataset.test.X)
self.assertTrue(check_prevalence_vector(estim_prevalences)) self.assertTrue(check_prevalence_vector(estim_prevalences))

View File

@ -17,13 +17,13 @@ class TestReplicability(unittest.TestCase):
with qp.util.temp_seed(0): with qp.util.temp_seed(0):
lr = LogisticRegression(random_state=0, max_iter=10000) lr = LogisticRegression(random_state=0, max_iter=10000)
pacc = PACC(lr) pacc = PACC(lr)
prev = pacc.fit(dataset.training).quantify(dataset.test.X) prev = pacc.fit(dataset.training).predict(dataset.test.X)
str_prev1 = strprev(prev, prec=5) str_prev1 = strprev(prev, prec=5)
with qp.util.temp_seed(0): with qp.util.temp_seed(0):
lr = LogisticRegression(random_state=0, max_iter=10000) lr = LogisticRegression(random_state=0, max_iter=10000)
pacc = PACC(lr) pacc = PACC(lr)
prev2 = pacc.fit(dataset.training).quantify(dataset.test.X) prev2 = pacc.fit(dataset.training).predict(dataset.test.X)
str_prev2 = strprev(prev2, prec=5) str_prev2 = strprev(prev2, prec=5)
self.assertEqual(str_prev1, str_prev2) self.assertEqual(str_prev1, str_prev2)
@ -85,17 +85,17 @@ class TestReplicability(unittest.TestCase):
with qp.util.temp_seed(10): with qp.util.temp_seed(10):
pacc = PACC(LogisticRegression(), val_split=2, n_jobs=2) pacc = PACC(LogisticRegression(), val_split=2, n_jobs=2)
pacc.fit(train, val_split=0.5) pacc.fit(train, val_split=0.5)
prev1 = F.strprev(pacc.quantify(test.instances)) prev1 = F.strprev(pacc.predict(test.instances))
with qp.util.temp_seed(0): with qp.util.temp_seed(0):
pacc = PACC(LogisticRegression(), val_split=2, n_jobs=2) pacc = PACC(LogisticRegression(), val_split=2, n_jobs=2)
pacc.fit(train, val_split=0.5) pacc.fit(train, val_split=0.5)
prev2 = F.strprev(pacc.quantify(test.instances)) prev2 = F.strprev(pacc.predict(test.instances))
with qp.util.temp_seed(0): with qp.util.temp_seed(0):
pacc = PACC(LogisticRegression(), val_split=2, n_jobs=2) pacc = PACC(LogisticRegression(), val_split=2, n_jobs=2)
pacc.fit(train, val_split=0.5) pacc.fit(train, val_split=0.5)
prev3 = F.strprev(pacc.quantify(test.instances)) prev3 = F.strprev(pacc.predict(test.instances))
print(prev1) print(prev1)
print(prev2) print(prev2)

testing_refactor.py Normal file
View File

@ -0,0 +1,16 @@
import quapy as qp
from quapy.method.aggregative import *
dataset_name = qp.datasets.UCI_MULTICLASS_DATASETS[1]
data = qp.datasets.fetch_UCIMulticlassDataset(dataset_name)
train, test = data.train_test
quant = EMQ()
quant.fit(*train.Xy)
prev = quant.predict(test.X)
print(prev)
# test CC, to prevent it from doing 5FCV for nothing
# test PACC or PCC with LinearSVC; removing "adapt_if_necessary" from _check_classifier