refactoring w/o labelled collection

This commit is contained in:
Alejandro Moreo Fernandez 2025-04-20 22:05:46 +02:00
parent c79b76516c
commit 075be93a23
22 changed files with 416 additions and 248 deletions

View File

@ -53,7 +53,7 @@ training, test = dataset.train_test
model = qp.method.aggregative.ACC()
model.fit(training)
estim_prevalence = model.quantify(test.X)
estim_prevalence = model.predict(test.X)
true_prevalence = test.prevalence()
error = qp.error.mae(true_prevalence, estim_prevalence)

View File

@ -1,3 +1,27 @@
To remove LabelledCollection from the methods:
- The mess stems from the confusing semantics of fit in aggregative methods, which takes 3 parameters:
  - data: a LabelledCollection, which can be:
    - the training set, if the classifier must be trained
    - None, if the classifier does not need to be trained
    - the validation set (which conflicts with val_split), if the classifier does not need to be trained
  - fit_classifier: says whether the classifier must be trained or not, and this changes the semantics of the other parameters
  - val_split: which can be:
    - an integer: the number of folds for kFCV, which implies fit_classifier=True and data=the whole training set
    - a fraction in [0,1]: the portion used for validation; implies fit_classifier=True and data=train+val
    - a LabelledCollection: a specific validation set; implies neither fit_classifier=True nor fit_classifier=False
- The way to remove the methods' dependency on LabelledCollection should be as follows (see the usage sketch after the TODO list below):
  - The constructor states whether the classifier received as a parameter must be trained or is already trained;
    that is, there is a fit_classifier=True or False.
  - fit_classifier=True:
    - data in fit is the whole training set, validation included
    - val_split:
      - int: number of folds for kFCV
      - a proportion in [0,1]
  - fit_classifier=False:
    - val_split: a tuple (X, y) with the validation data on which the classifier predictions are generated
- [TODO] Document confidence in manuals
- [TODO] Test the return_type="index" in protocols and finish the "distributing_samples.py" example
- [TODO] Add EDy (an implementation is available at quantificationlib)
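
A minimal usage sketch of the intended design (hedged: it assumes the refactored constructors introduced later in this commit, i.e., PACC accepting `fit_classifier` and `val_split`, together with the new `fit(X, y)`/`predict(X)` interface; the toy data is made up for illustration):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from quapy.method.aggregative import PACC

# toy data, for illustration only
rng = np.random.RandomState(0)
X = rng.randn(1000, 10)
y = (X[:, 0] + .5 * rng.randn(1000) > 0).astype(int)
Xtr, Xte, ytr, yte = train_test_split(X, y, random_state=0)

# fit_classifier=True: fit receives the whole training set;
# val_split=5 trains the aggregation function on 5-fold CV predictions
pacc = PACC(LogisticRegression(), fit_classifier=True, val_split=5)
pacc.fit(Xtr, ytr)
print(pacc.predict(Xte))

# fit_classifier=False: the classifier comes in already trained;
# val_split=(Xval, yval) provides the validation data explicitly
Xtr2, Xval, ytr2, yval = train_test_split(Xtr, ytr, random_state=0)
clf = LogisticRegression().fit(Xtr2, ytr2)
pacc = PACC(clf, fit_classifier=False, val_split=(Xval, yval))
pacc.fit(Xtr2, ytr2)  # only the aggregation function is trained here
print(pacc.predict(Xte))
```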

View File

@ -32,7 +32,7 @@ dataset = qp.datasets.fetch_twitter('semeval16')
model = qp.method.aggregative.ACC(LogisticRegression())
model.fit(dataset.training)
estim_prevalence = model.quantify(dataset.test.instances)
estim_prevalence = model.predict(dataset.test.instances)
true_prevalence = dataset.test.prevalence()
error = qp.error.mae(true_prevalence, estim_prevalence)

View File

@ -132,7 +132,7 @@ svm = LinearSVC()
# (an alias is available in qp.method.aggregative.ClassifyAndCount)
model = qp.method.aggregative.CC(svm)
model.fit(training)
estim_prevalence = model.quantify(test.instances)
estim_prevalence = model.predict(test.instances)
```
The same code could be used to instantiate an ACC, by simply replacing
@ -172,7 +172,7 @@ The following code illustrates the case in which PCC is used:
```python
model = qp.method.aggregative.PCC(svm)
model.fit(training)
estim_prevalence = model.quantify(test.instances)
estim_prevalence = model.predict(test.instances)
print('classifier:', model.classifier)
```
In this case, QuaPy will print:
@ -263,7 +263,7 @@ dataset = qp.datasets.fetch_twitter('hcr', pickle=True)
model = qp.method.aggregative.EMQ(LogisticRegression())
model.fit(dataset.training)
estim_prevalence = model.quantify(dataset.test.instances)
estim_prevalence = model.predict(dataset.test.instances)
```
_New in v0.1.7_: EMQ now accepts two new parameters in the construction method, namely
@ -299,6 +299,7 @@ HDy was proposed as a binary classifier and the implementation
provided in QuaPy accepts only binary datasets.
The following code shows an example of use:
```python
import quapy as qp
from sklearn.linear_model import LogisticRegression
@ -309,7 +310,7 @@ qp.data.preprocessing.text2tfidf(dataset, min_df=5, inplace=True)
model = qp.method.aggregative.HDy(LogisticRegression())
model.fit(dataset.training)
estim_prevalence = model.quantify(dataset.test.instances)
estim_prevalence = model.predict(dataset.test.instances)
```
_New in v0.1.7:_ QuaPy now provides an implementation of the generalized
@ -411,7 +412,7 @@ qp.environ['SVMPERF_HOME'] = '../svm_perf_quantification'
model = newOneVsAll(SVMQ(), n_jobs=-1)  # run them in parallel
model.fit(dataset.training)
estim_prevalence = model.quantify(dataset.test.instances)
estim_prevalence = model.predict(dataset.test.instances)
```
Check the examples on [explicit_loss_minimization](https://github.com/HLT-ISTI/QuaPy/blob/devel/examples/5.explicit_loss_minimization.py)
@ -531,7 +532,7 @@ dataset = qp.datasets.fetch_UCIBinaryDataset('haberman')
model = Ensemble(quantifier=ACC(LogisticRegression()), size=30, policy='ave', n_jobs=-1)
model.fit(dataset.training)
estim_prevalence = model.quantify(dataset.test.instances)
estim_prevalence = model.predict(dataset.test.instances)
```
Other aggregation policies implemented in QuaPy include:
@ -579,6 +580,6 @@ learner = NeuralClassifierTrainer(cnn, device='cuda')
# train QuaNet
model = QuaNet(learner, device='cuda')
model.fit(dataset.training)
estim_prevalence = model.quantify(dataset.test.instances)
estim_prevalence = model.predict(dataset.test.instances)
```

View File

@ -41,7 +41,7 @@ pacc.fit(train)
# let's now test our quantifier on the test data (of course, we should not use the test labels y at this point, only X)
X_test = test.X
estim_prevalence = pacc.quantify(X_test)
estim_prevalence = pacc.predict(X_test)
print(f'estimated test prevalence = {F.strprev(estim_prevalence)}')
print(f'true test prevalence = {F.strprev(test.prevalence())}')

View File

@ -123,7 +123,7 @@ def _get_estimate(estimator_class, training: LabelledCollection, test: np.ndarra
"""Auxiliary method for running ACC and PACC."""
estimator = estimator_class(get_random_forest())
estimator.fit(training)
return estimator.quantify(test)
return estimator.predict(test)
def train_and_plot_bayesian_quantification(ax: plt.Axes, training: LabelledCollection, test: LabelledCollection) -> None:
@ -133,7 +133,7 @@ def train_and_plot_bayesian_quantification(ax: plt.Axes, training: LabelledColle
quantifier.fit(training)
# Obtain mean prediction
mean_prediction = quantifier.quantify(test.X)
mean_prediction = quantifier.predict(test.X)
mae = qp.error.mae(test.prevalence(), mean_prediction)
x_ax = np.arange(training.n_classes)
ax.plot(x_ax, mean_prediction, c="salmon", linewidth=2, linestyle=":", label="Bayesian")

View File

@ -39,10 +39,10 @@ class MyQuantifier(BaseQuantifier):
return self
# in general, we would need to implement the method predict(self, X); this would amount to:
def quantify(self, instances):
def predict(self, X):
assert hasattr(self.classifier, 'predict_proba'), \
'the underlying classifier is not probabilistic! [abort]'
posterior_probabilities = self.classifier.predict_proba(instances)
posterior_probabilities = self.classifier.predict_proba(X)
positive_probabilities = posterior_probabilities[:, 1]
crisp_decisions = positive_probabilities > self.alpha
pos_prev = crisp_decisions.mean()

View File

@ -27,7 +27,7 @@ quantifier = QuaNet(cnn_classifier, device='cuda')
quantifier.fit(train, fit_classifier=False)
# prediction and evaluation
estim_prevalence = quantifier.quantify(test.instances)
estim_prevalence = quantifier.predict(test.instances)
mae = qp.error.mae(test.prevalence(), estim_prevalence)
print(f'true prevalence: {F.strprev(test.prevalence())}')

View File

@ -14,7 +14,7 @@ from . import model_selection
from . import classification
import os
__version__ = '0.1.10'
__version__ = '0.1.10r'
environ = {
'SAMPLE_SIZE': None,
@ -24,7 +24,7 @@ environ = {
'PAD_INDEX': 1,
'SVMPERF_HOME': './svm_perf_quantification',
'N_JOBS': int(os.getenv('N_JOBS', 1)),
'DEFAULT_CLS': LogisticRegression(max_iter=3000)
'DEFAULT_CLS': LogisticRegression()
}
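
These defaults are meant to be overridden at runtime; a small sketch using only the keys listed above:

```python
import quapy as qp

qp.environ['SAMPLE_SIZE'] = 100  # sample size assumed by protocols and sample-based error metrics
qp.environ['N_JOBS'] = -1        # default parallelism for components that honor n_jobs
```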

View File

@ -232,11 +232,11 @@ class LabelledCollection:
:return: two instances of :class:`LabelledCollection`, the first one with `train_prop` elements, and the
second one with `1-train_prop` elements
"""
tr_docs, te_docs, tr_labels, te_labels = train_test_split(
tr_X, te_X, tr_y, te_y = train_test_split(
self.instances, self.labels, train_size=train_prop, stratify=self.labels, random_state=random_state
)
training = LabelledCollection(tr_docs, tr_labels, classes=self.classes_)
test = LabelledCollection(te_docs, te_labels, classes=self.classes_)
training = LabelledCollection(tr_X, tr_y, classes=self.classes_)
test = LabelledCollection(te_X, te_y, classes=self.classes_)
return training, test
def split_random(self, train_prop=0.6, random_state=None):

View File

@ -63,7 +63,7 @@ def prediction(
protocol_with_predictions = protocol.on_preclassified_instances(pre_classified)
return __prediction_helper(model.aggregate, protocol_with_predictions, verbose)
else:
return __prediction_helper(model.quantify, protocol, verbose)
return __prediction_helper(model.predict, protocol, verbose)
def __prediction_helper(quantification_fn, protocol: AbstractProtocol, verbose=False):

View File

@ -38,7 +38,7 @@ class QuaNetTrainer(BaseQuantifier):
>>> # train QuaNet (QuaNet is an alias to QuaNetTrainer)
>>> model = QuaNet(classifier, qp.environ['SAMPLE_SIZE'], device='cuda')
>>> model.fit(dataset.training)
>>> estim_prevalence = model.quantify(dataset.test.instances)
>>> estim_prevalence = model.predict(dataset.test.instances)
:param classifier: an object implementing `fit` (i.e., that can be trained on labelled data),
`predict_proba` (i.e., that can generate posterior probabilities of unlabelled examples) and
@ -201,9 +201,9 @@ class QuaNetTrainer(BaseQuantifier):
return prevs_estim
def quantify(self, instances):
posteriors = self.classifier.predict_proba(instances)
embeddings = self.classifier.transform(instances)
def predict(self, X):
posteriors = self.classifier.predict_proba(X)
embeddings = self.classifier.transform(X)
quant_estims = self._get_aggregative_estims(posteriors)
self.quanet.eval()
with torch.no_grad():

View File

@ -5,8 +5,10 @@ import numpy as np
from abstention.calibration import NoBiasVectorScaling, TempScaling, VectorScaling
from sklearn.base import BaseEstimator
from sklearn.calibration import CalibratedClassifierCV
from sklearn.exceptions import NotFittedError
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import cross_val_predict
from sklearn.model_selection import cross_val_predict, train_test_split
from sklearn.utils.validation import check_is_fitted
import quapy as qp
import quapy.functional as F
@ -14,6 +16,7 @@ from quapy.functional import get_divergence
from quapy.classification.svmperf import SVMperf
from quapy.data import LabelledCollection
from quapy.method.base import BaseQuantifier, BinaryQuantifier, OneVsAllGeneric
from quapy.method import _bayesian
# Abstract classes
@ -35,18 +38,53 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
and :meth:`aggregate`.
"""
val_split_ = None
def __init__(self, classifier: Union[None,BaseEstimator], fit_classifier:bool=True, val_split:Union[int,float,tuple,None]=5):
self.classifier = qp._get_classifier(classifier)
self.fit_classifier = fit_classifier
self.val_split = val_split
@property
def val_split(self):
return self.val_split_
# basic type checks
assert hasattr(self.classifier, 'fit'), \
f'the classifier does not implement "fit"'
@val_split.setter
def val_split(self, val_split):
if isinstance(val_split, LabelledCollection):
print('warning: setting val_split with a LabelledCollection will be inefficient in'
'model selection. Rather pass the LabelledCollection at fit time')
self.val_split_ = val_split
assert isinstance(fit_classifier, bool), \
f'unexpected type for {fit_classifier=}; must be True or False'
if isinstance(val_split, int):
assert val_split > 1, \
(f'when {val_split=} is indicated as an integer, it represents the number of folds in a kFCV '
f'and must thus be >1')
assert fit_classifier, (f'when {val_split=} is indicated as an integer (the number of folds for kFCV) '
f'the parameter {fit_classifier=} must be True')
elif isinstance(val_split, float):
assert 0 < val_split < 1, \
(f'when {val_split=} is indicated as a float, it represents the fraction of training instances '
f'to be used for validation, and must thus be in the range (0,1)')
assert fit_classifier, (f'when {val_split=} is indicated as a float (the fraction of training instances '
f'to be used for validation), the parameter {fit_classifier=} must be True')
elif isinstance(val_split, tuple):
assert len(val_split) == 2, \
(f'when {val_split=} is indicated as a tuple, it represents the validation collection (X,y), '
f'and must thus contain exactly two elements')
elif val_split is None:
pass
else:
raise ValueError(f'unexpected type for {val_split=}')
# classifier is fitted?
try:
check_is_fitted(self.classifier)
fitted = True
except NotFittedError:
fitted = False
# consistency checks: fit_classifier?
if self.fit_classifier:
if fitted:
raise RuntimeWarning(f'the classifier is already fitted, but {fit_classifier=} was requested')
else:
assert fitted, (f'{fit_classifier=} requires the classifier to be already trained, '
f'but it does not seem to have been fitted')
def _check_init_parameters(self):
"""
@ -58,20 +96,36 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
"""
pass
def _check_non_empty_classes(self, data: LabelledCollection):
def _check_non_empty_classes(self, y):
"""
Asserts all classes have positive instances.
:param data: LabelledCollection
:param y: array-like of shape `(n_instances,)` with the label for each instance; prevalence values are
computed against `self.classes_`, so that classes with no examples are correctly detected.
:return: Nothing. May raise an exception.
"""
sample_prevs = data.prevalence()
sample_prevs = F.prevalence_from_labels(y, self.classes_)
empty_classes = np.argwhere(sample_prevs == 0).flatten()
if len(empty_classes) > 0:
empty_class_names = data.classes_[empty_classes]
empty_class_names = self.classes_[empty_classes]
raise ValueError(f'classes {empty_class_names} have no training examples')
def fit(self, data: LabelledCollection, fit_classifier=True, val_split=None):
def fit(self, X, y):
"""
Trains the aggregative quantifier. This comes down to training a classifier (if requested) and an
aggregation function.
:param X: array-like, the training instances
:param y: array-like, the labels
:return: self
"""
self._check_init_parameters()
classif_predictions = self.classifier_fit_predict(X, y)
self.aggregation_fit(classif_predictions)
return self
def fit_depr(self, data: LabelledCollection, fit_classifier=True, val_split=None):
"""
Trains the aggregative quantifier. This comes down to training a classifier and an aggregation function.
@ -88,94 +142,55 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
:return: self
"""
self._check_init_parameters()
classif_predictions = self.classifier_fit_predict(data, fit_classifier, predict_on=val_split)
self.aggregation_fit(classif_predictions, data)
classif_predictions = self.classifier_fit_predict_depr(data, fit_classifier, predict_on=val_split)
self.aggregation_fit_depr(classif_predictions, data)
return self
def classifier_fit_predict(self, data: LabelledCollection, fit_classifier=True, predict_on=None):
def classifier_fit_predict(self, X, y):
"""
Trains the classifier if requested (`fit_classifier=True`) and generates the necessary predictions to
train the aggregation function.
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param predict_on: specifies the set on which predictions need to be issued. This parameter can
be specified as None (default) to indicate no prediction is needed; a float in (0, 1) to
indicate the proportion of instances to be used for predictions (the remainder is used for
training); an integer >1 to indicate that the predictions must be generated via k-fold
cross-validation, using this integer as k; or the data sample itself on which to generate
the predictions.
:param X: array-like, the training instances
:param y: array-like, the labels
:return: a tuple (predictions, labels) with the classifier predictions on the validation data
and the corresponding true labels
"""
assert isinstance(fit_classifier, bool), 'unexpected type for "fit_classifier", must be boolean'
self._check_classifier()
self._check_classifier(adapt_if_necessary=(self._classifier_method() == 'predict_proba'))
# self._check_non_empty_classes(y)
if fit_classifier:
self._check_non_empty_classes(data)
if predict_on is None:
if not fit_classifier:
predict_on = data
if isinstance(self.val_split, LabelledCollection) and self.val_split!=predict_on:
raise ValueError(f'{fit_classifier=} but a LabelledCollection was provided as val_split '
f'in __init__ that is not the same as the LabelledCollection provided in fit.')
if predict_on is None:
predict_on = self.val_split
if predict_on is None:
if fit_classifier:
self.classifier.fit(*data.Xy)
predictions = None
elif isinstance(predict_on, float):
if fit_classifier:
if not (0. < predict_on < 1.):
raise ValueError(f'proportion {predict_on=} out of range, must be in (0,1)')
train, val = data.split_stratified(train_prop=(1 - predict_on))
self.classifier.fit(*train.Xy)
predictions = LabelledCollection(self.classify(val.X), val.y, classes=data.classes_)
else:
raise ValueError(f'wrong type for predict_on: since fit_classifier=False, '
f'the set on which predictions have to be issued must be '
f'explicitly indicated')
elif isinstance(predict_on, LabelledCollection):
if fit_classifier:
self.classifier.fit(*data.Xy)
predictions = LabelledCollection(self.classify(predict_on.X), predict_on.y, classes=predict_on.classes_)
elif isinstance(predict_on, int):
if fit_classifier:
if predict_on <= 1:
raise ValueError(f'invalid value {predict_on} in fit. '
f'Specify a integer >1 for kFCV estimation.')
else:
if isinstance(self.val_split, int):
assert self.fit_classifier, f'unexpected value for {self.fit_classifier=}'
num_folds = self.val_split
n_jobs = self.n_jobs if hasattr(self, 'n_jobs') else qp._get_njobs(None)
predictions = cross_val_predict(
self.classifier, *data.Xy, cv=predict_on, n_jobs=n_jobs, method=self._classifier_method())
predictions = LabelledCollection(predictions, data.y, classes=data.classes_)
self.classifier.fit(*data.Xy)
predictions = cross_val_predict(self.classifier, X, y, cv=num_folds, n_jobs=n_jobs, method=self._classifier_method())
yval = y
self.classifier.fit(X, y)
elif isinstance(self.val_split, float):
assert self.fit_classifier, f'unexpected value for {self.fit_classifier=}'
train_prop = 1. - self.val_split
Xtr, Xval, ytr, yval = train_test_split(X, y, train_size=train_prop, stratify=y)
self.classifier.fit(Xtr, ytr)
predictions = self.classify(Xval)
elif isinstance(self.val_split, tuple):
Xval, yval = self.val_split
if self.fit_classifier:
self.classifier.fit(X, y)
elif self.val_split is None:
if self.fit_classifier:
self.classifier.fit(X, y)
predictions, yval = None, None
else:
raise ValueError(f'wrong type for predict_on: since fit_classifier=False, '
f'the set on which predictions have to be issued must be '
f'explicitly indicated')
raise ValueError(f'unexpected type for {self.val_split=}')
else:
raise ValueError(
f'error: param "predict_on" ({type(predict_on)}) not understood; '
f'use either a float indicating the split proportion, or a '
f'tuple (X,y) indicating the validation partition')
return predictions
return predictions, yval
@abstractmethod
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
def aggregation_fit(self, classif_predictions, **kwargs):
"""
Trains the aggregation function.
:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
as instances, the predictions issued by the classifier and, as labels, the true labels
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
:param classif_predictions: the classification predictions on the validation data, as returned by
:meth:`classifier_fit_predict`, i.e., a tuple (predictions, labels)
"""
...
@ -197,16 +212,16 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
"""
self.classifier_ = classifier
def classify(self, instances):
def classify(self, X):
"""
Provides the label predictions for the given instances. The predictions should respect the format expected by
:meth:`aggregate`, e.g., posterior probabilities for probabilistic quantifiers, or crisp predictions for
non-probabilistic quantifiers. The default one is "decision_function".
:param instances: array-like of shape `(n_instances, n_features,)`
:param X: array-like of shape `(n_instances, n_features,)`
:return: np.ndarray of shape `(n_instances,)` with label predictions
"""
return getattr(self.classifier, self._classifier_method())(instances)
return getattr(self.classifier, self._classifier_method())(X)
def _classifier_method(self):
"""
@ -221,26 +236,26 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
Guarantees that the underlying classifier implements the method required for issuing predictions, i.e.,
the method indicated by the :meth:`_classifier_method`
:param adapt_if_necessary: if True, the method will try to comply with the required specifications
:param adapt_if_necessary: unused unless overridden
"""
assert hasattr(self.classifier, self._classifier_method()), \
f"the method does not implement the required {self._classifier_method()} method"
def quantify(self, instances):
def predict(self, X):
"""
Generate class prevalence estimates for the sample's instances by aggregating the label predictions generated
by the classifier.
:param instances: array-like
:param X: array-like
:return: `np.ndarray` of shape `(n_classes)` with class prevalence estimates.
"""
classif_predictions = self.classify(instances)
classif_predictions = self.classify(X)
return self.aggregate(classif_predictions)
@abstractmethod
def aggregate(self, classif_predictions: np.ndarray):
"""
Implements the aggregation of label predictions.
Implements the aggregation of the classifier predictions.
:param classif_predictions: `np.ndarray` of label predictions
:return: `np.ndarray` of shape `(n_classes,)` with class prevalence estimates.
@ -324,9 +339,9 @@ class BinaryAggregativeQuantifier(AggregativeQuantifier, BinaryQuantifier):
def neg_label(self):
return self.classifier.classes_[0]
def fit(self, data: LabelledCollection, fit_classifier=True, val_split=None):
self._check_binary(data, self.__class__.__name__)
return super().fit(data, fit_classifier, val_split)
def fit(self, X, y):
self._check_binary(y, self.__class__.__name__)
return super().fit(X, y)
# Methods
@ -338,16 +353,14 @@ class CC(AggregativeCrispQuantifier):
:param classifier: a sklearn's Estimator that generates a classifier
"""
def __init__(self, classifier: BaseEstimator = None, fit_classifier: bool = True):
super().__init__(classifier, fit_classifier, val_split=None)
def __init__(self, classifier: BaseEstimator=None):
self.classifier = qp._get_classifier(classifier)
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
def aggregation_fit(self, classif_predictions):
"""
Nothing to do here!
:param classif_predictions: not used
:param data: not used
"""
pass
@ -369,15 +382,14 @@ class PCC(AggregativeSoftQuantifier):
:param classifier: a sklearn's Estimator that generates a classifier
"""
def __init__(self, classifier: BaseEstimator=None):
self.classifier = qp._get_classifier(classifier)
def __init__(self, classifier: BaseEstimator = None, fit_classifier: bool = True):
super().__init__(classifier, fit_classifier, val_split=None)
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
def aggregation_fit(self, classif_predictions):
"""
Nothing to do here!
:param classif_predictions: not used
:param data: not used
"""
pass
@ -430,17 +442,18 @@ class ACC(AggregativeCrispQuantifier):
:param n_jobs: number of parallel workers
"""
def __init__(
self,
classifier: BaseEstimator = None,
fit_classifier=True,
val_split=5,
solver: Literal['minimize', 'exact', 'exact-raise', 'exact-cc'] = 'minimize',
method: Literal['inversion', 'invariant-ratio'] = 'inversion',
norm: Literal['clip', 'mapsimplex', 'condsoftmax'] = 'clip',
n_jobs=None,
):
self.classifier = qp._get_classifier(classifier)
self.val_split = val_split
super().__init__(classifier, fit_classifier, val_split)
self.n_jobs = qp._get_njobs(n_jobs)
self.solver = solver
self.method = method
@ -451,13 +464,14 @@ class ACC(AggregativeCrispQuantifier):
NORMALIZATIONS = ['clip', 'mapsimplex', 'condsoftmax', None]
@classmethod
def newInvariantRatioEstimation(cls, classifier: BaseEstimator, val_split=5, n_jobs=None):
def newInvariantRatioEstimation(cls, classifier: BaseEstimator, fit_classifier=True, val_split=5, n_jobs=None):
"""
Constructs a quantifier that implements the Invariant Ratio Estimator of
`Vaz et al. 2018 <https://jmlr.org/papers/v20/18-456.html>`_. This amounts
to setting method to 'invariant-ratio' and clipping to 'project'.
:param classifier: a sklearn's Estimator that generates a classifier
:param fit_classifier: bool, whether to fit the classifier or not
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
@ -468,7 +482,7 @@ class ACC(AggregativeCrispQuantifier):
:param n_jobs: number of parallel workers
:return: an instance of ACC configured so that it implements the Invariant Ratio Estimator
"""
return ACC(classifier, val_split=val_split, method='invariant-ratio', norm='mapsimplex', n_jobs=n_jobs)
return ACC(classifier, fit_classifier=fit_classifier, val_split=val_split, method='invariant-ratio', norm='mapsimplex', n_jobs=n_jobs)
def _check_init_parameters(self):
if self.solver not in ACC.SOLVERS:
@ -478,7 +492,7 @@ class ACC(AggregativeCrispQuantifier):
if self.norm not in ACC.NORMALIZATIONS:
raise ValueError(f"unknown normalization; valid ones are {ACC.NORMALIZATIONS}")
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
def aggregation_fit(self, classif_predictions):
"""
Estimates the misclassification rates.
@ -486,8 +500,8 @@ class ACC(AggregativeCrispQuantifier):
as instances, the label predictions issued by the classifier and, as labels, the true labels
"""
pred_labels, true_labels = classif_predictions.Xy
self.cc = CC(self.classifier)
pred_labels, true_labels = classif_predictions
self.cc = CC(self.classifier, fit_classifier=False)
self.Pte_cond_estim_ = ACC.getPteCondEstim(self.classifier.classes_, true_labels, pred_labels)
@classmethod
@ -529,6 +543,8 @@ class PACC(AggregativeSoftQuantifier):
:param classifier: a sklearn's Estimator that generates a classifier
:param fit_classifier: bool, whether to fit the classifier or not
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
@ -565,17 +581,18 @@ class PACC(AggregativeSoftQuantifier):
:param n_jobs: number of parallel workers
"""
def __init__(
self,
classifier: BaseEstimator = None,
fit_classifier=True,
val_split=5,
solver: Literal['minimize', 'exact', 'exact-raise', 'exact-cc'] = 'minimize',
method: Literal['inversion', 'invariant-ratio'] = 'inversion',
norm: Literal['clip', 'mapsimplex', 'condsoftmax'] = 'clip',
n_jobs=None
):
self.classifier = qp._get_classifier(classifier)
self.val_split = val_split
super().__init__(classifier, fit_classifier, val_split)
self.n_jobs = qp._get_njobs(n_jobs)
self.solver = solver
self.method = method
@ -589,7 +606,7 @@ class PACC(AggregativeSoftQuantifier):
if self.norm not in ACC.NORMALIZATIONS:
raise ValueError(f"unknown normalization; valid ones are {ACC.NORMALIZATIONS}")
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
def aggregation_fit(self, classif_predictions):
"""
Estimates the misclassification rates
@ -597,8 +614,8 @@ class PACC(AggregativeSoftQuantifier):
as instances, the posterior probabilities issued by the classifier and, as labels, the true labels
"""
posteriors, true_labels = classif_predictions.Xy
self.pcc = PCC(self.classifier)
posteriors, true_labels = classif_predictions
self.pcc = PCC(self.classifier, fit_classifier=False)
self.Pte_cond_estim_ = PACC.getPteCondEstim(self.classifier.classes_, true_labels, posteriors)
def aggregate(self, classif_posteriors):
@ -640,6 +657,7 @@ class EMQ(AggregativeSoftQuantifier):
and to recalibrate the posterior probabilities of the classifier.
:param classifier: a sklearn's Estimator that generates a classifier
:param fit_classifier: bool, whether to fit the classifier or not
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer, indicating that the predictions
@ -663,15 +681,15 @@ class EMQ(AggregativeSoftQuantifier):
MAX_ITER = 1000
EPSILON = 1e-4
def __init__(self, classifier: BaseEstimator=None, val_split=None, exact_train_prev=True, recalib=None, n_jobs=None):
self.classifier = qp._get_classifier(classifier)
self.val_split = val_split
def __init__(self, classifier: BaseEstimator = None, fit_classifier=True, val_split=None, exact_train_prev=True, recalib=None,
n_jobs=None):
super().__init__(classifier, fit_classifier, val_split)
self.exact_train_prev = exact_train_prev
self.recalib = recalib
self.n_jobs = n_jobs
@classmethod
def EMQ_BCTS(cls, classifier: BaseEstimator, n_jobs=None):
def EMQ_BCTS(cls, classifier: BaseEstimator, fit_classifier=True, val_split=5, n_jobs=None):
"""
Constructs an instance of EMQ using the best configuration found in the `Alexandari et al. paper
<http://proceedings.mlr.press/v119/alexandari20a.html>`_, i.e., one that relies on Bias-Corrected Temperature
@ -682,7 +700,7 @@ class EMQ(AggregativeSoftQuantifier):
:param n_jobs: number of parallel workers.
:return: An instance of EMQ with BCTS
"""
return EMQ(classifier, val_split=5, exact_train_prev=False, recalib='bcts', n_jobs=n_jobs)
return EMQ(classifier, fit_classifier=fit_classifier, val_split=val_split, exact_train_prev=False, recalib='bcts', n_jobs=n_jobs)
def _check_init_parameters(self):
if self.val_split is not None:
@ -698,30 +716,30 @@ class EMQ(AggregativeSoftQuantifier):
f'indicating the number of folds for kFCV.')
self.val_split = 5
def classify(self, instances):
def classify(self, X):
"""
Provides the posterior probabilities for the given instances. If the classifier was required
to be recalibrated, then these posteriors are recalibrated accordingly.
:param instances: array-like of shape `(n_instances, n_dimensions,)`
:param X: array-like of shape `(n_instances, n_dimensions,)`
:return: np.ndarray of shape `(n_instances, n_classes,)` with posterior probabilities
"""
posteriors = self.classifier.predict_proba(instances)
posteriors = self.classifier.predict_proba(X)
if hasattr(self, 'calibration_function') and self.calibration_function is not None:
posteriors = self.calibration_function(posteriors)
return posteriors
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
def aggregation_fit(self, classif_predictions):
"""
Trains the aggregation function of EMQ. This comes down to recalibrating the posterior probabilities
if requested.
:param classif_predictions: a tuple (P, y) with the posterior probabilities issued by the classifier
and the true labels
"""
P, y = classif_predictions
n_classes = len(self.classes_)
if self.recalib is not None:
P, y = classif_predictions.Xy
if self.recalib == 'nbvs':
calibrator = NoBiasVectorScaling()
elif self.recalib == 'bcts':
@ -735,11 +753,11 @@ class EMQ(AggregativeSoftQuantifier):
'"nbvs", "bcts", "ts", and "vs".')
if not np.issubdtype(y.dtype, np.number):
y = np.searchsorted(data.classes_, y)
self.calibration_function = calibrator(P, np.eye(data.n_classes)[y], posterior_supplied=True)
y = np.searchsorted(self.classes_, y)
self.calibration_function = calibrator(P, np.eye(n_classes)[y], posterior_supplied=True)
if self.exact_train_prev:
self.train_prevalence = data.prevalence()
self.train_prevalence = F.prevalence_from_labels(y, self.classes_)
else:
train_posteriors = P
if self.recalib is not None:
@ -806,6 +824,101 @@ class EMQ(AggregativeSoftQuantifier):
return qs, ps
class BayesianCC(AggregativeCrispQuantifier):
"""
`Bayesian quantification <https://arxiv.org/abs/2302.09159>`_ method,
which is a variant of :class:`ACC` that calculates the posterior probability distribution
over the prevalence vectors, rather than providing a point estimate obtained
by matrix inversion.
It can be used to diagnose degeneracy in the predictions (visible when the confusion
matrix has a high condition number) or to quantify the uncertainty around the point estimate.
This method relies on extra dependencies, which have to be installed via:
`$ pip install quapy[bayes]`
:param classifier: a sklearn's Estimator that generates a classifier
:param val_split: a float in (0, 1) indicating the proportion of the training data to be used,
as a stratified held-out validation set, for generating classifier predictions.
:param num_warmup: number of warmup iterations for the MCMC sampler (default 500)
:param num_samples: number of samples to draw from the posterior (default 1000)
:param mcmc_seed: random seed for the MCMC sampler (default 0)
"""
def __init__(self,
classifier: BaseEstimator = None,
val_split: float = 0.75,
num_warmup: int = 500,
num_samples: int = 1_000,
mcmc_seed: int = 0):
if num_warmup <= 0:
raise ValueError(f'parameter {num_warmup=} must be a positive integer')
if num_samples <= 0:
raise ValueError(f'parameter {num_samples=} must be a positive integer')
if (not isinstance(val_split, float)) or val_split <= 0 or val_split >= 1:
raise ValueError(f'val_split must be a float in (0, 1), got {val_split}')
if _bayesian.DEPENDENCIES_INSTALLED is False:
raise ImportError("Auxiliary dependencies are required. Run `$ pip install quapy[bayes]` to install them.")
self.classifier = qp._get_classifier(classifier)
self.val_split = val_split
self.num_warmup = num_warmup
self.num_samples = num_samples
self.mcmc_seed = mcmc_seed
# Array of shape (n_classes, n_predicted_classes,) where entry (y, c) is the number of instances
# labeled as class y and predicted as class c.
# By default, this array is set to None and later defined as part of the `aggregation_fit` phase
self._n_and_c_labeled = None
# Dictionary with posterior samples, set when `aggregate` is provided.
self._samples = None
def aggregation_fit(self, classif_predictions):
"""
Estimates the misclassification rates.
:param classif_predictions: a tuple (pred_labels, true_labels) with the label predictions issued
by the classifier and the true labels
"""
pred_labels, true_labels = classif_predictions
self._n_and_c_labeled = confusion_matrix(y_true=true_labels, y_pred=pred_labels,
labels=self.classifier.classes_)
def sample_from_posterior(self, classif_predictions):
if self._n_and_c_labeled is None:
raise ValueError("aggregation_fit must be called before sample_from_posterior")
n_c_unlabeled = F.counts_from_labels(classif_predictions, self.classifier.classes_)
self._samples = _bayesian.sample_posterior(
n_c_unlabeled=n_c_unlabeled,
n_y_and_c_labeled=self._n_and_c_labeled,
num_warmup=self.num_warmup,
num_samples=self.num_samples,
seed=self.mcmc_seed,
)
return self._samples
def get_prevalence_samples(self):
if self._samples is None:
raise ValueError("sample_from_posterior must be called before get_prevalence_samples")
return self._samples[_bayesian.P_TEST_Y]
def get_conditional_probability_samples(self):
if self._samples is None:
raise ValueError("sample_from_posterior must be called before get_conditional_probability_samples")
return self._samples[_bayesian.P_C_COND_Y]
def aggregate(self, classif_predictions):
samples = self.sample_from_posterior(classif_predictions)[_bayesian.P_TEST_Y]
return np.asarray(samples.mean(axis=0), dtype=float)
class HDy(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
"""
`Hellinger Distance y <https://www.sciencedirect.com/science/article/pii/S0020025512004069>`_ (HDy).
@ -897,7 +1010,8 @@ class DyS(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
:param n_jobs: number of parallel workers.
"""
def __init__(self, classifier: BaseEstimator=None, val_split=5, n_bins=8, divergence: Union[str, Callable]= 'HD', tol=1e-05, n_jobs=None):
def __init__(self, classifier: BaseEstimator = None, val_split=5, n_bins=8, divergence: Union[str, Callable] = 'HD',
tol=1e-05, n_jobs=None):
self.classifier = qp._get_classifier(classifier)
self.val_split = val_split
self.tol = tol
@ -1083,6 +1197,7 @@ class DMy(AggregativeSoftQuantifier):
test_distribution = self._get_distributions(posteriors)
divergence = get_divergence(self.divergence)
n_classes, n_channels, nbins = self.validation_distribution.shape
def loss(prev):
prev = np.expand_dims(prev, axis=0)
mixture_distribution = (prev @ self.validation_distribution.reshape(n_classes, -1)).reshape(n_channels, -1)
@ -1092,7 +1207,6 @@ class DMy(AggregativeSoftQuantifier):
return F.argmin_prevalence(loss, n_classes, method=self.search)
def newELM(svmperf_base=None, loss='01', C=1):
"""
Explicit Loss Minimization (ELM) quantifiers.
@ -1145,6 +1259,7 @@ def newSVMQ(svmperf_base=None, C=1):
"""
return newELM(svmperf_base, loss='q', C=C)
def newSVMKLD(svmperf_base=None, C=1):
"""
SVM(KLD) is an Explicit Loss Minimization (ELM) quantifier set to optimize for the Kullback-Leibler Divergence
@ -1195,6 +1310,7 @@ def newSVMKLD(svmperf_base=None, C=1):
"""
return newELM(svmperf_base, loss='nkld', C=C)
def newSVMAE(svmperf_base=None, C=1):
"""
SVM(AE) is an Explicit Loss Minimization (ELM) quantifier set to optimize for the Absolute Error as first used by
@ -1219,6 +1335,7 @@ def newSVMAE(svmperf_base=None, C=1):
"""
return newELM(svmperf_base, loss='mae', C=C)
def newSVMRAE(svmperf_base=None, C=1):
"""
SVM(RAE) is an Explicit Loss Minimization (ELM) quantifier set to optimize for the Relative Absolute Error as first
@ -1269,7 +1386,7 @@ class OneVsAllAggregative(OneVsAllGeneric, AggregativeQuantifier):
self.n_jobs = qp._get_njobs(n_jobs)
self.parallel_backend = parallel_backend
def classify(self, instances):
def classify(self, X):
"""
If the base quantifier is not probabilistic, returns a matrix of shape `(n,m,)` with `n` the number of
instances and `m` the number of classes. The entry `(i,j)` is a binary value indicating whether instance
@ -1280,11 +1397,11 @@ class OneVsAllAggregative(OneVsAllGeneric, AggregativeQuantifier):
posterior probability that instance `i` belongs (resp. does not belong) to class `j`. The posterior
probabilities are independent of each other, meaning that, in general, they do not sum up to one.
:param instances: array-like
:param X: array-like
:return: `np.ndarray`
"""
classif_predictions = self._parallel(self._delayed_binary_classification, instances)
classif_predictions = self._parallel(self._delayed_binary_classification, X)
if isinstance(self.binary_quantifier, AggregativeSoftQuantifier):
return np.swapaxes(classif_predictions, 0, 1)
else:
@ -1314,6 +1431,7 @@ class AggregativeMedianEstimator(BinaryQuantifier):
:param param_grid: the grid of parameters over which the median will be computed
:param n_jobs: number of parallel workers
"""
def __init__(self, base_quantifier: AggregativeQuantifier, param_grid: dict, random_state=None, n_jobs=None):
self.base_quantifier = base_quantifier
self.param_grid = param_grid
@ -1350,7 +1468,6 @@ class AggregativeMedianEstimator(BinaryQuantifier):
model.aggregation_fit(predictions, training)
return model
def fit(self, training: LabelledCollection, **kwargs):
import itertools
@ -1419,7 +1536,6 @@ X = _threshold_optim.X
MS = _threshold_optim.MS
MS2 = _threshold_optim.MS2
from . import _kdey
KDEyML = _kdey.KDEyML

View File

@ -14,26 +14,36 @@ import numpy as np
class BaseQuantifier(BaseEstimator):
"""
Abstract Quantifier. A quantifier is defined as an object of a class that implements the method :meth:`fit` on
:class:`quapy.data.base.LabelledCollection`, the method :meth:`quantify`, and the :meth:`set_params` and
a pair X, y, the method :meth:`predict`, and the :meth:`set_params` and
:meth:`get_params` for model selection (see :meth:`quapy.model_selection.GridSearchQ`)
"""
@abstractmethod
def fit(self, data: LabelledCollection):
def fit(self, X, y):
"""
Trains a quantifier.
Generates a quantifier.
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
:param X: array-like, the training instances
:param y: array-like, the labels
:return: self
"""
...
@abstractmethod
def quantify(self, instances):
def predict(self, X):
"""
Generate class prevalence estimates for the sample's instances
:param instances: array-like
:param X: array-like, the test instances
:return: `np.ndarray` of shape `(n_classes,)` with class prevalence estimates.
"""
...
def quantify(self, X):
"""
Alias to :meth:`predict`, kept for backward compatibility
:param X: array-like
:return: `np.ndarray` of shape `(n_classes,)` with class prevalence estimates.
"""
return self.predict(X)
@ -45,8 +55,9 @@ class BinaryQuantifier(BaseQuantifier):
(typically, to be interpreted as one class and its complement).
"""
def _check_binary(self, data: LabelledCollection, quantifier_name):
assert data.binary, f'{quantifier_name} works only on problems of binary classification. ' \
def _check_binary(self, y, quantifier_name):
n_classes = len(set(y))
assert n_classes==2, f'{quantifier_name} works only on problems of binary classification. ' \
f'Use the class OneVsAll to enable {quantifier_name} to work on single-label data.'
@ -66,7 +77,7 @@ def newOneVsAll(binary_quantifier: BaseQuantifier, n_jobs=None):
class OneVsAllGeneric(OneVsAll, BaseQuantifier):
"""
Allows any binary quantifier to perform quantification on single-label datasets. The method maintains one binary
quantifier for each class, and then l1-normalizes the outputs so that the class prevelence values sum up to 1.
quantifier for each class, and then l1-normalizes the outputs so that the class prevalence values sum up to 1.
"""
def __init__(self, binary_quantifier: BaseQuantifier, n_jobs=None):
@ -93,8 +104,8 @@ class OneVsAllGeneric(OneVsAll, BaseQuantifier):
)
)
def quantify(self, instances):
prevalences = self._parallel(self._delayed_binary_predict, instances)
def predict(self, X):
prevalences = self._parallel(self._delayed_binary_predict, X)
return qp.functional.normalize_prevalence(prevalences)
@property
@ -102,7 +113,7 @@ class OneVsAllGeneric(OneVsAll, BaseQuantifier):
return sorted(self.dict_binary_quantifiers.keys())
def _delayed_binary_predict(self, c, X):
return self.dict_binary_quantifiers[c].quantify(X)[1]
return self.dict_binary_quantifiers[c].predict(X)[1]
def _delayed_binary_fit(self, c, data):
bindata = LabelledCollection(data.instances, data.labels == c, classes=[False, True])

View File

@ -72,12 +72,12 @@ class MedianEstimator2(BinaryQuantifier):
def _delayed_predict(self, args):
model, instances = args
return model.quantify(instances)
return model.predict(instances)
def quantify(self, instances):
def predict(self, X):
prev_preds = qp.util.parallel(
self._delayed_predict,
((model, instances) for model in self.models),
((model, X) for model in self.models),
seed=qp.environ.get('_R_SEED', None),
n_jobs=self.n_jobs
)
@ -174,12 +174,12 @@ class MedianEstimator(BinaryQuantifier):
def _delayed_predict(self, args):
model, instances = args
return model.quantify(instances)
return model.predict(instances)
def quantify(self, instances):
def predict(self, X):
prev_preds = qp.util.parallel(
self._delayed_predict,
((model, instances) for model in self.models),
((model, X) for model in self.models),
seed=qp.environ.get('_R_SEED', None),
n_jobs=self.n_jobs,
asarray=False
@ -294,15 +294,15 @@ class Ensemble(BaseQuantifier):
self._sout('Fit [Done]')
return self
def quantify(self, instances):
def predict(self, X):
predictions = np.asarray(
qp.util.parallel(_delayed_quantify, ((Qi, instances) for Qi in self.ensemble), n_jobs=self.n_jobs)
qp.util.parallel(_delayed_quantify, ((Qi, X) for Qi in self.ensemble), n_jobs=self.n_jobs)
)
if self.policy == 'ptr':
predictions = self._ptr_policy(predictions)
elif self.policy == 'ds':
predictions = self._ds_policy(predictions, instances)
predictions = self._ds_policy(predictions, X)
predictions = np.mean(predictions, axis=0)
return F.normalize_prevalence(predictions)
@ -470,7 +470,7 @@ def _delayed_new_instance(args):
def _delayed_quantify(args):
quantifier, instances = args
return quantifier[0].quantify(instances)
return quantifier[0].predict(instances)
def _draw_simplex(ndim, min_val, max_trials=100):

View File

@ -30,11 +30,11 @@ class MaximumLikelihoodPrevalenceEstimation(BaseQuantifier):
self.estimated_prevalence = data.prevalence()
return self
def quantify(self, instances):
def predict(self, X):
"""
Ignores the input instances and returns, as the class prevalence estimates, the training prevalence.
:param instances: array-like (ignored)
:param X: array-like (ignored)
:return: the class prevalence seen during training
"""
return self.estimated_prevalence
@ -122,20 +122,20 @@ class DMx(BaseQuantifier):
return self
def quantify(self, instances):
def predict(self, X):
"""
Searches for the mixture model parameter (the sought prevalence values) that yields a validation distribution
(the mixture) that best matches the test distribution, in terms of the divergence measure of choice.
The matching is computed as the average dissimilarity (in terms of the dissimilarity measure of choice)
between all feature-specific discrete distributions.
:param instances: instances in the sample
:param X: instances in the sample
:return: a vector of class prevalence estimates
"""
assert instances.shape[1] == self.nfeats, f'wrong shape; expected {self.nfeats}, found {instances.shape[1]}'
assert X.shape[1] == self.nfeats, f'wrong shape; expected {self.nfeats}, found {X.shape[1]}'
test_distribution = self.__get_distributions(instances)
test_distribution = self.__get_distributions(X)
divergence = get_divergence(self.divergence)
n_classes, n_feats, nbins = self.validation_distribution.shape
def loss(prev):
@ -163,8 +163,8 @@ class ReadMe(BaseQuantifier):
X = self.vectorizer.fit_transform(X)
self.class_conditional_X = {c: X[y == c] for c in data.classes_}
def quantify(self, instances):
X = self.vectorizer.transform(instances)
def predict(self, X):
X = self.vectorizer.transform(X)
# number of features
num_docs, num_feats = X.shape

View File

@ -275,15 +275,15 @@ class GridSearchQ(BaseQuantifier):
return self
def quantify(self, instances):
def predict(self, X):
"""Estimate class prevalence values using the best model found after calling the :meth:`fit` method.
:param instances: sample containing the instances
:param X: sample containing the instances
:return: a ndarray of shape `(n_classes)` with class prevalence estimates as according to the best model found
by the model selection process.
"""
assert hasattr(self, 'best_model_'), 'predict called before fit'
return self.best_model().quantify(instances)
return self.best_model().predict(X)
def set_params(self, **parameters):
"""Sets the hyper-parameters to explore.
@ -365,7 +365,7 @@ def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfol
for train, test in data.kFCV(nfolds=nfolds, random_state=random_state):
quantifier.fit(train)
fold_prev = quantifier.quantify(test.X)
fold_prev = quantifier.predict(test.X)
rel_size = 1. * len(test) / len(data)
total_prev += fold_prev*rel_size

View File

@ -18,7 +18,7 @@ class TestDatasets(unittest.TestCase):
q = self.new_quantifier()
print(f'testing method {q} in {dataset.name}...', end='')
q.fit(dataset.training)
estim_prevalences = q.quantify(dataset.test.instances)
estim_prevalences = q.predict(dataset.test.instances)
self.assertTrue(F.check_prevalence_vector(estim_prevalences))
print(f'[done]')
@ -26,7 +26,7 @@ class TestDatasets(unittest.TestCase):
for X, p in gen():
if vectorizer is not None:
X = vectorizer.transform(X)
estim_prevalences = q.quantify(X)
estim_prevalences = q.predict(X)
self.assertTrue(F.check_prevalence_vector(estim_prevalences))
max_samples_test -= 1
if max_samples_test == 0:

View File

@ -41,8 +41,8 @@ class EvalTestCase(unittest.TestCase):
def __init__(self, cls):
self.emq = EMQ(cls)
def quantify(self, instances):
return self.emq.quantify(instances)
def predict(self, X):
return self.emq.predict(X)
def fit(self, data):
self.emq.fit(data)

View File

@ -51,7 +51,7 @@ class TestMethods(unittest.TestCase):
q = model(learner)
print('testing', q)
q.fit(dataset.training, fit_classifier=False)
estim_prevalences = q.quantify(dataset.test.X)
estim_prevalences = q.predict(dataset.test.X)
self.assertTrue(check_prevalence_vector(estim_prevalences))
def test_non_aggregative(self):
@ -65,7 +65,7 @@ class TestMethods(unittest.TestCase):
q = model()
print(f'testing {q} on dataset {dataset.name}')
q.fit(dataset.training)
estim_prevalences = q.quantify(dataset.test.X)
estim_prevalences = q.predict(dataset.test.X)
self.assertTrue(check_prevalence_vector(estim_prevalences))
def test_ensembles(self):
@ -81,7 +81,7 @@ class TestMethods(unittest.TestCase):
print(f'testing {base_quantifier} on dataset {dataset.name} with {policy=}')
ensemble = Ensemble(quantifier=base_quantifier, size=3, policy=policy, n_jobs=-1)
ensemble.fit(dataset.training)
estim_prevalences = ensemble.quantify(dataset.test.instances)
estim_prevalences = ensemble.predict(dataset.test.instances)
self.assertTrue(check_prevalence_vector(estim_prevalences))
def test_quanet(self):
@ -107,7 +107,7 @@ class TestMethods(unittest.TestCase):
model = QuaNet(learner, device='cpu', n_epochs=2, tr_iter_per_poch=10, va_iter_per_poch=10, patience=2)
model.fit(dataset.training)
estim_prevalences = model.quantify(dataset.test.instances)
estim_prevalences = model.predict(dataset.test.instances)
self.assertTrue(check_prevalence_vector(estim_prevalences))
def test_composable(self):
@ -115,7 +115,7 @@ class TestMethods(unittest.TestCase):
for q in COMPOSABLE_METHODS:
print('testing', q)
q.fit(dataset.training)
estim_prevalences = q.quantify(dataset.test.X)
estim_prevalences = q.predict(dataset.test.X)
self.assertTrue(check_prevalence_vector(estim_prevalences))

View File

@ -17,13 +17,13 @@ class TestReplicability(unittest.TestCase):
with qp.util.temp_seed(0):
lr = LogisticRegression(random_state=0, max_iter=10000)
pacc = PACC(lr)
prev = pacc.fit(dataset.training).quantify(dataset.test.X)
prev = pacc.fit(dataset.training).predict(dataset.test.X)
str_prev1 = strprev(prev, prec=5)
with qp.util.temp_seed(0):
lr = LogisticRegression(random_state=0, max_iter=10000)
pacc = PACC(lr)
prev2 = pacc.fit(dataset.training).quantify(dataset.test.X)
prev2 = pacc.fit(dataset.training).predict(dataset.test.X)
str_prev2 = strprev(prev2, prec=5)
self.assertEqual(str_prev1, str_prev2)
@ -85,17 +85,17 @@ class TestReplicability(unittest.TestCase):
with qp.util.temp_seed(10):
pacc = PACC(LogisticRegression(), val_split=2, n_jobs=2)
pacc.fit(train, val_split=0.5)
prev1 = F.strprev(pacc.quantify(test.instances))
prev1 = F.strprev(pacc.predict(test.instances))
with qp.util.temp_seed(0):
pacc = PACC(LogisticRegression(), val_split=2, n_jobs=2)
pacc.fit(train, val_split=0.5)
prev2 = F.strprev(pacc.quantify(test.instances))
prev2 = F.strprev(pacc.predict(test.instances))
with qp.util.temp_seed(0):
pacc = PACC(LogisticRegression(), val_split=2, n_jobs=2)
pacc.fit(train, val_split=0.5)
prev3 = F.strprev(pacc.quantify(test.instances))
prev3 = F.strprev(pacc.predict(test.instances))
print(prev1)
print(prev2)

16
testing_refactor.py Normal file
View File

@ -0,0 +1,16 @@
import quapy as qp
from quapy.method.aggregative import *
dataset_name = qp.datasets.UCI_MULTICLASS_DATASETS[1]
data = qp.datasets.fetch_UCIMulticlassDataset(dataset_name)
train, test = data.train_test
quant = EMQ()
quant.fit(*train.Xy)
prev = quant.predict(test.X)
print(prev)
# test CC, preventing it from doing 5FCV for nothing
# test PACC or PCC with LinearSVC; removing "adapt_if_necessary" from _check_classifier
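
# A hedged sketch of those remaining checks (assumptions: the refactored CC/PACC
# constructors accept fit_classifier, and _check_classifier raises an AssertionError
# for non-probabilistic classifiers once "adapt_if_necessary" is removed):
from sklearn.svm import LinearSVC

quant = CC()  # CC trains no aggregation function, so no 5FCV should be triggered
quant.fit(*train.Xy)
print(quant.predict(test.X))

quant = PACC(LinearSVC())  # LinearSVC does not implement predict_proba
try:
    quant.fit(*train.Xy)
except AssertionError as e:
    print(f'PACC rejected the non-probabilistic classifier: {e}')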