forked from moreo/QuaPy
fango
This commit is contained in:
parent
173db83c28
commit
e870d798b7
|
@ -2,7 +2,7 @@ import quapy as qp
|
||||||
from quapy.data import LabelledCollection
|
from quapy.data import LabelledCollection
|
||||||
from quapy.method.base import BinaryQuantifier
|
from quapy.method.base import BinaryQuantifier
|
||||||
from quapy.model_selection import GridSearchQ
|
from quapy.model_selection import GridSearchQ
|
||||||
from quapy.method.aggregative import AggregativeProbabilisticQuantifier
|
from quapy.method.aggregative import AggregativeSoftQuantifier
|
||||||
from quapy.protocol import APP
|
from quapy.protocol import APP
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.linear_model import LogisticRegression
|
from sklearn.linear_model import LogisticRegression
|
||||||
|
@ -15,7 +15,7 @@ from sklearn.linear_model import LogisticRegression
|
||||||
# internal hyperparameter (let's say, alpha) which is the decision threshold. Let's also assume the quantifier
|
# internal hyperparameter (let's say, alpha) which is the decision threshold. Let's also assume the quantifier
|
||||||
# is binary, for simplicity.
|
# is binary, for simplicity.
|
||||||
|
|
||||||
class MyQuantifier(AggregativeProbabilisticQuantifier, BinaryQuantifier):
|
class MyQuantifier(AggregativeSoftQuantifier, BinaryQuantifier):
|
||||||
def __init__(self, classifier, alpha=0.5):
|
def __init__(self, classifier, alpha=0.5):
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
# aggregative quantifiers have an internal self.classifier attribute
|
# aggregative quantifiers have an internal self.classifier attribute
|
||||||
|
|
|
@ -24,7 +24,8 @@ class RecalibratedProbabilisticClassifier:
|
||||||
class RecalibratedProbabilisticClassifierBase(BaseEstimator, RecalibratedProbabilisticClassifier):
|
class RecalibratedProbabilisticClassifierBase(BaseEstimator, RecalibratedProbabilisticClassifier):
|
||||||
"""
|
"""
|
||||||
Applies a (re)calibration method from `abstention.calibration`, as defined in
|
Applies a (re)calibration method from `abstention.calibration`, as defined in
|
||||||
`Alexandari et al. paper <http://proceedings.mlr.press/v119/alexandari20a.html>`_:
|
`Alexandari et al. paper <http://proceedings.mlr.press/v119/alexandari20a.html>`_.
|
||||||
|
|
||||||
|
|
||||||
:param classifier: a scikit-learn probabilistic classifier
|
:param classifier: a scikit-learn probabilistic classifier
|
||||||
:param calibrator: the calibration object (an instance of abstention.calibration.CalibratorFactory)
|
:param calibrator: the calibration object (an instance of abstention.calibration.CalibratorFactory)
|
||||||
|
|
|
@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Callable, Union
|
from typing import Callable, Union
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from abstention.calibration import NoBiasVectorScaling, TempScaling, VectorScaling
|
||||||
from scipy import optimize
|
from scipy import optimize
|
||||||
from sklearn.base import BaseEstimator
|
from sklearn.base import BaseEstimator
|
||||||
from sklearn.calibration import CalibratedClassifierCV
|
from sklearn.calibration import CalibratedClassifierCV
|
||||||
|
@ -46,7 +47,7 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
|
||||||
:return: self
|
:return: self
|
||||||
"""
|
"""
|
||||||
classif_predictions = self.classifier_fit_predict(data, fit_classifier)
|
classif_predictions = self.classifier_fit_predict(data, fit_classifier)
|
||||||
self.aggregation_fit(classif_predictions)
|
self.aggregation_fit(classif_predictions, data)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def classifier_fit_predict(self, data: LabelledCollection, fit_classifier=True, predict_on=None):
|
def classifier_fit_predict(self, data: LabelledCollection, fit_classifier=True, predict_on=None):
|
||||||
|
@ -66,7 +67,6 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
|
||||||
"""
|
"""
|
||||||
assert isinstance(fit_classifier, bool), 'unexpected type for "fit_classifier", must be boolean'
|
assert isinstance(fit_classifier, bool), 'unexpected type for "fit_classifier", must be boolean'
|
||||||
|
|
||||||
print(type(self))
|
|
||||||
self._check_classifier(adapt_if_necessary=(self._classifier_method() == 'predict_proba'))
|
self._check_classifier(adapt_if_necessary=(self._classifier_method() == 'predict_proba'))
|
||||||
|
|
||||||
if predict_on is None:
|
if predict_on is None:
|
||||||
|
@ -80,7 +80,7 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
|
||||||
raise ValueError(f'proportion {predict_on=} out of range, must be in (0,1)')
|
raise ValueError(f'proportion {predict_on=} out of range, must be in (0,1)')
|
||||||
train, val = data.split_stratified(train_prop=(1 - predict_on))
|
train, val = data.split_stratified(train_prop=(1 - predict_on))
|
||||||
self.classifier.fit(*train.Xy)
|
self.classifier.fit(*train.Xy)
|
||||||
predictions = (self.classify(val.X), val.y)
|
predictions = LabelledCollection(self.classify(val.X), val.y, classes=data.classes_)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'wrong type for predict_on: since fit_classifier=False, '
|
raise ValueError(f'wrong type for predict_on: since fit_classifier=False, '
|
||||||
f'the set on which predictions have to be issued must be '
|
f'the set on which predictions have to be issued must be '
|
||||||
|
@ -89,15 +89,17 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
|
||||||
elif isinstance(predict_on, LabelledCollection):
|
elif isinstance(predict_on, LabelledCollection):
|
||||||
if fit_classifier:
|
if fit_classifier:
|
||||||
self.classifier.fit(*data.Xy)
|
self.classifier.fit(*data.Xy)
|
||||||
predictions = (self.classify(predict_on.X), predict_on.y)
|
predictions = LabelledCollection(self.classify(predict_on.X), predict_on.y, classes=predict_on.classes_)
|
||||||
|
|
||||||
elif isinstance(predict_on, int):
|
elif isinstance(predict_on, int):
|
||||||
if fit_classifier:
|
if fit_classifier:
|
||||||
if not predict_on > 1:
|
if predict_on <= 1:
|
||||||
raise ValueError(f'invalid value {predict_on} in fit. '
|
raise ValueError(f'invalid value {predict_on} in fit. '
|
||||||
f'Specify a integer >1 for kFCV estimation.')
|
f'Specify a integer >1 for kFCV estimation.')
|
||||||
|
else:
|
||||||
predictions = cross_val_predict(
|
predictions = cross_val_predict(
|
||||||
classifier, *data.Xy, cv=predict_on, n_jobs=self.n_jobs, method=self._classifier_method())
|
self.classifier, *data.Xy, cv=predict_on, n_jobs=self.n_jobs, method=self._classifier_method())
|
||||||
|
predictions = LabelledCollection(predictions, data.y, classes=data.classes_)
|
||||||
self.classifier.fit(*data.Xy)
|
self.classifier.fit(*data.Xy)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'wrong type for predict_on: since fit_classifier=False, '
|
raise ValueError(f'wrong type for predict_on: since fit_classifier=False, '
|
||||||
|
@ -113,12 +115,13 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
|
||||||
return predictions
|
return predictions
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def aggregation_fit(self, classif_predictions: LabelledCollection):
|
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||||
"""
|
"""
|
||||||
Trains the aggregation function.
|
Trains the aggregation function.
|
||||||
|
|
||||||
:param classif_predictions: typically an `ndarray` containing the label predictions, but could be a
|
:param classif_predictions: a LabelledCollection containing the label predictions issued
|
||||||
tuple containing any information needed for fitting the aggregation function
|
by the classifier
|
||||||
|
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@ -140,23 +143,36 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
|
||||||
"""
|
"""
|
||||||
self.classifier_ = classifier
|
self.classifier_ = classifier
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def classify(self, instances):
|
def classify(self, instances):
|
||||||
"""
|
"""
|
||||||
Provides the label predictions for the given instances. The predictions should respect the format expected by
|
Provides the label predictions for the given instances. The predictions should respect the format expected by
|
||||||
:meth:`aggregate`, i.e., posterior probabilities for probabilistic quantifiers, or crisp predictions for
|
:meth:`aggregate`, e.g., posterior probabilities for probabilistic quantifiers, or crisp predictions for
|
||||||
non-probabilistic quantifiers
|
non-probabilistic quantifiers
|
||||||
|
|
||||||
:param instances: array-like
|
:param instances: array-like of shape `(n_instances, n_features,)`
|
||||||
:return: np.ndarray of shape `(n_instances,)` with label predictions
|
:return: np.ndarray of shape `(n_instances,)` with label predictions
|
||||||
"""
|
"""
|
||||||
return self.classifier.predict(instances)
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def _classifier_method(self):
|
def _classifier_method(self):
|
||||||
print('using predict')
|
"""
|
||||||
return 'predict'
|
Name of the method that must be used for issuing label predictions.
|
||||||
|
|
||||||
|
:return: string
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def _check_classifier(self, adapt_if_necessary=False):
|
def _check_classifier(self, adapt_if_necessary=False):
|
||||||
assert hasattr(self.classifier, self._classifier_method())
|
"""
|
||||||
|
Guarantees that the underlying classifier implements the method required for issuing predictions, i.e.,
|
||||||
|
the method indicated by the :meth:`_classifier_method`
|
||||||
|
|
||||||
|
:param adapt_if_necessary: if True, the method will try to comply with the required specifications
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
def quantify(self, instances):
|
def quantify(self, instances):
|
||||||
"""
|
"""
|
||||||
|
@ -190,22 +206,77 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
|
||||||
return self.classifier.classes_
|
return self.classifier.classes_
|
||||||
|
|
||||||
|
|
||||||
class AggregativeProbabilisticQuantifier(AggregativeQuantifier, ABC):
|
class AggregativeCrispQuantifier(AggregativeQuantifier, ABC):
|
||||||
"""
|
"""
|
||||||
Abstract class for quantification methods that base their estimations on the aggregation of posterior probabilities
|
Abstract class for quantification methods that base their estimations on the aggregation of crisp decisions
|
||||||
as returned by a probabilistic classifier. Aggregative Probabilistic Quantifiers thus extend Aggregative
|
as returned by a hard classifier. Aggregative crisp quantifiers thus extend Aggregative
|
||||||
Quantifiers by implementing a _posterior_probabilities_ method returning values in [0,1] -- the posterior
|
Quantifiers by implementing specifications about crisp predictions.
|
||||||
probabilities.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def classify(self, instances):
|
def classify(self, instances):
|
||||||
|
"""
|
||||||
|
Provides the label (crisp) predictions for the given instances.
|
||||||
|
|
||||||
|
:param instances: array-like of shape `(n_instances, n_dimensions,)`
|
||||||
|
:return: np.ndarray of shape `(n_instances,)` with label predictions
|
||||||
|
"""
|
||||||
|
return self.classifier.predict(instances)
|
||||||
|
|
||||||
|
def _classifier_method(self):
|
||||||
|
"""
|
||||||
|
Name of the method that must be used for issuing label predictions.
|
||||||
|
|
||||||
|
:return: the string "predict", i.e., the standard method name for scikit-learn hard predictions
|
||||||
|
"""
|
||||||
|
print('using predict')
|
||||||
|
return 'predict'
|
||||||
|
|
||||||
|
def _check_classifier(self, adapt_if_necessary=False):
|
||||||
|
"""
|
||||||
|
Guarantees that the underlying classifier implements the method indicated by the :meth:`_classifier_method`
|
||||||
|
|
||||||
|
:param adapt_if_necessary: unused, added for compatibility
|
||||||
|
"""
|
||||||
|
assert hasattr(self.classifier, self._classifier_method()), \
|
||||||
|
f"the method does not implement the required {self._classifier_method()} method"
|
||||||
|
|
||||||
|
|
||||||
|
class AggregativeSoftQuantifier(AggregativeQuantifier, ABC):
|
||||||
|
"""
|
||||||
|
Abstract class for quantification methods that base their estimations on the aggregation of posterior
|
||||||
|
probabilities as returned by a probabilistic classifier.
|
||||||
|
Aggregative soft quantifiers thus extend Aggregative Quantifiers by implementing specifications
|
||||||
|
about soft predictions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def classify(self, instances):
|
||||||
|
"""
|
||||||
|
Provides the posterior probabilities for the given instances.
|
||||||
|
|
||||||
|
:param instances: array-like of shape `(n_instances, n_dimensions,)`
|
||||||
|
:return: np.ndarray of shape `(n_instances, n_classes,)` with posterior probabilities
|
||||||
|
"""
|
||||||
return self.classifier.predict_proba(instances)
|
return self.classifier.predict_proba(instances)
|
||||||
|
|
||||||
def _classifier_method(self):
|
def _classifier_method(self):
|
||||||
|
"""
|
||||||
|
Name of the method that must be used for issuing label predictions.
|
||||||
|
|
||||||
|
:return: the string "predict_proba", i.e., the standard method name for scikit-learn soft predictions
|
||||||
|
"""
|
||||||
print('using predict_proba')
|
print('using predict_proba')
|
||||||
return 'predict_proba'
|
return 'predict_proba'
|
||||||
|
|
||||||
def _check_classifier(self, adapt_if_necessary=False):
|
def _check_classifier(self, adapt_if_necessary=False):
|
||||||
|
"""
|
||||||
|
Guarantees that the underlying classifier implements the method indicated by the :meth:`_classifier_method`.
|
||||||
|
In case it does not, the classifier is calibrated (by means of Platt's calibration method implemented by
|
||||||
|
scikit-learn in CalibratedClassifierCV, with cv=5). This calibration is only allowed if `adapt_if_necessary`
|
||||||
|
is set to True. If otherwise (i.e., the classifier is not probabilistic, and `adapt_if_necessary` is set
|
||||||
|
to False), an exception will be raised.
|
||||||
|
|
||||||
|
:param adapt_if_necessary: a hard classifier is turned into a soft classifier if `adapt_if_necessary==True`
|
||||||
|
"""
|
||||||
if not hasattr(self.classifier, self._classifier_method()):
|
if not hasattr(self.classifier, self._classifier_method()):
|
||||||
if adapt_if_necessary:
|
if adapt_if_necessary:
|
||||||
print(f'warning: The learner {self.classifier.__class__.__name__} does not seem to be '
|
print(f'warning: The learner {self.classifier.__class__.__name__} does not seem to be '
|
||||||
|
@ -217,9 +288,42 @@ class AggregativeProbabilisticQuantifier(AggregativeQuantifier, ABC):
|
||||||
f'fit_classifier is set to False')
|
f'fit_classifier is set to False')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CorrectionbasedAggregativeQuantifier(AggregativeQuantifier):
|
||||||
|
"""
|
||||||
|
Abstract class for quantification methods that carry out an adjustment (or correction) that requires,
|
||||||
|
at training time, the predictions to be issued in validation mode, i.e., on a set of held-out data that
|
||||||
|
is not the training set. There are three ways in which this distinction can be made, depending on how
|
||||||
|
the internal parameter `val_split` is specified, namely, (i) a float in (0, 1) indicating the proportion
|
||||||
|
of training instances that should be devoted to validation, or (ii) an integer indicating the
|
||||||
|
number of folds to consider in a k-fold cross-validation mode, or (iii) the specific set of data to
|
||||||
|
use for validation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def val_split(self):
|
||||||
|
return self.val_split_
|
||||||
|
|
||||||
|
@val_split.setter
|
||||||
|
def val_split(self, val_split):
|
||||||
|
if isinstance(val_split, LabelledCollection):
|
||||||
|
print('warning: setting val_split with a LabelledCollection will be inefficient in'
|
||||||
|
'model selection. Rather pass the LabelledCollection at fit time')
|
||||||
|
self.val_split_ = val_split
|
||||||
|
|
||||||
|
def fit(self, data: LabelledCollection, fit_classifier=True, predict_on=None):
|
||||||
|
print('method from CorrectionbasedAggregativeQuantifier')
|
||||||
|
if predict_on is None:
|
||||||
|
predict_on = self.val_split
|
||||||
|
classif_predictions = self.classifier_fit_predict(data, fit_classifier, predict_on)
|
||||||
|
self.aggregation_fit(classif_predictions, data)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Methods
|
# Methods
|
||||||
# ------------------------------------
|
# ------------------------------------
|
||||||
class CC(AggregativeQuantifier):
|
class CC(AggregativeCrispQuantifier):
|
||||||
"""
|
"""
|
||||||
The most basic Quantification method. One that simply classifies all instances and counts how many have been
|
The most basic Quantification method. One that simply classifies all instances and counts how many have been
|
||||||
attributed to each of the classes in order to compute class prevalence estimates.
|
attributed to each of the classes in order to compute class prevalence estimates.
|
||||||
|
@ -230,7 +334,7 @@ class CC(AggregativeQuantifier):
|
||||||
def __init__(self, classifier: BaseEstimator):
|
def __init__(self, classifier: BaseEstimator):
|
||||||
self.classifier = classifier
|
self.classifier = classifier
|
||||||
|
|
||||||
def aggregation_fit(self, classif_predictions: LabelledCollection):
|
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||||
"""
|
"""
|
||||||
Nothing to do here!
|
Nothing to do here!
|
||||||
|
|
||||||
|
@ -248,19 +352,21 @@ class CC(AggregativeQuantifier):
|
||||||
return F.prevalence_from_labels(classif_predictions, self.classes_)
|
return F.prevalence_from_labels(classif_predictions, self.classes_)
|
||||||
|
|
||||||
|
|
||||||
class ACC(AggregativeQuantifier):
|
class ACC(AggregativeCrispQuantifier, CorrectionbasedAggregativeQuantifier):
|
||||||
"""
|
"""
|
||||||
`Adjusted Classify & Count <https://link.springer.com/article/10.1007/s10618-008-0097-y>`_,
|
`Adjusted Classify & Count <https://link.springer.com/article/10.1007/s10618-008-0097-y>`_,
|
||||||
the "adjusted" variant of :class:`CC`, that corrects the predictions of CC
|
the "adjusted" variant of :class:`CC`, that corrects the predictions of CC
|
||||||
according to the `misclassification rates`.
|
according to the `misclassification rates`.
|
||||||
|
|
||||||
:param classifier: a sklearn's Estimator that generates a classifier
|
:param classifier: a sklearn's Estimator that generates a classifier
|
||||||
:param val_split: indicates the proportion of data to be used as a stratified held-out validation set in which the
|
:param val_split: specifies the data used for generating classifier predictions. This specification
|
||||||
misclassification rates are to be estimated.
|
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
|
||||||
This parameter can be indicated as a real value (between 0 and 1, default 0.4), representing a proportion of
|
be extracted from the training set (default 0.4); or as an integer, indicating that the predictions
|
||||||
validation data, or as an integer, indicating that the misclassification rates should be estimated via
|
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
|
||||||
`k`-fold cross validation (this integer stands for the number of folds `k`), or as a
|
for `k`); or as a collection defining the specific set of data to use for validation.
|
||||||
:class:`quapy.data.base.LabelledCollection` (the split itself).
|
Alternatively, this set can be specified at fit time by indicating the exact set of data
|
||||||
|
on which the predictions are to be generated.
|
||||||
|
:param n_jobs: number of parallel workers
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, classifier: BaseEstimator, val_split=0.4, n_jobs=None):
|
def __init__(self, classifier: BaseEstimator, val_split=0.4, n_jobs=None):
|
||||||
|
@ -268,7 +374,7 @@ class ACC(AggregativeQuantifier):
|
||||||
self.val_split = val_split
|
self.val_split = val_split
|
||||||
self.n_jobs = qp._get_njobs(n_jobs)
|
self.n_jobs = qp._get_njobs(n_jobs)
|
||||||
|
|
||||||
def aggregation_fit(self, classif_predictions: LabelledCollection):
|
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||||
"""
|
"""
|
||||||
Estimates the misclassification rates.
|
Estimates the misclassification rates.
|
||||||
|
|
||||||
|
@ -292,9 +398,6 @@ class ACC(AggregativeQuantifier):
|
||||||
conf[:, i] /= class_counts[i]
|
conf[:, i] /= class_counts[i]
|
||||||
return conf
|
return conf
|
||||||
|
|
||||||
def classify(self, data):
|
|
||||||
return self.cc.classify(data)
|
|
||||||
|
|
||||||
def aggregate(self, classif_predictions):
|
def aggregate(self, classif_predictions):
|
||||||
prevs_estim = self.cc.aggregate(classif_predictions)
|
prevs_estim = self.cc.aggregate(classif_predictions)
|
||||||
return ACC.solve_adjustment(self.Pte_cond_estim_, prevs_estim)
|
return ACC.solve_adjustment(self.Pte_cond_estim_, prevs_estim)
|
||||||
|
@ -321,7 +424,7 @@ class ACC(AggregativeQuantifier):
|
||||||
return adjusted_prevs
|
return adjusted_prevs
|
||||||
|
|
||||||
|
|
||||||
class PCC(AggregativeProbabilisticQuantifier):
|
class PCC(AggregativeSoftQuantifier):
|
||||||
"""
|
"""
|
||||||
`Probabilistic Classify & Count <https://ieeexplore.ieee.org/abstract/document/5694031>`_,
|
`Probabilistic Classify & Count <https://ieeexplore.ieee.org/abstract/document/5694031>`_,
|
||||||
the probabilistic variant of CC that relies on the posterior probabilities returned by a probabilistic classifier.
|
the probabilistic variant of CC that relies on the posterior probabilities returned by a probabilistic classifier.
|
||||||
|
@ -332,7 +435,7 @@ class PCC(AggregativeProbabilisticQuantifier):
|
||||||
def __init__(self, classifier: BaseEstimator):
|
def __init__(self, classifier: BaseEstimator):
|
||||||
self.classifier = classifier
|
self.classifier = classifier
|
||||||
|
|
||||||
def aggregation_fit(self, classif_predictions: LabelledCollection):
|
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||||
"""
|
"""
|
||||||
Nothing to do here!
|
Nothing to do here!
|
||||||
|
|
||||||
|
@ -344,18 +447,18 @@ class PCC(AggregativeProbabilisticQuantifier):
|
||||||
return F.prevalence_from_probabilities(classif_posteriors, binarize=False)
|
return F.prevalence_from_probabilities(classif_posteriors, binarize=False)
|
||||||
|
|
||||||
|
|
||||||
class PACC(AggregativeProbabilisticQuantifier):
|
class PACC(AggregativeSoftQuantifier, CorrectionbasedAggregativeQuantifier):
|
||||||
"""
|
"""
|
||||||
`Probabilistic Adjusted Classify & Count <https://ieeexplore.ieee.org/abstract/document/5694031>`_,
|
`Probabilistic Adjusted Classify & Count <https://ieeexplore.ieee.org/abstract/document/5694031>`_,
|
||||||
the probabilistic variant of ACC that relies on the posterior probabilities returned by a probabilistic classifier.
|
the probabilistic variant of ACC that relies on the posterior probabilities returned by a probabilistic classifier.
|
||||||
|
|
||||||
:param classifier: a sklearn's Estimator that generates a classifier
|
:param classifier: a sklearn's Estimator that generates a classifier
|
||||||
:param val_split: indicates the proportion of data to be used as a stratified held-out validation set in which the
|
:param val_split: specifies the data used for generating classifier predictions. This specification
|
||||||
misclassification rates are to be estimated.
|
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
|
||||||
This parameter can be indicated as a real value (between 0 and 1, default 0.4), representing a proportion of
|
be extracted from the training set (default 0.4); or as an integer, indicating that the predictions
|
||||||
validation data, or as an integer, indicating that the misclassification rates should be estimated via
|
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
|
||||||
`k`-fold cross validation (this integer stands for the number of folds `k`), or as a
|
for `k`). Alternatively, this set can be specified at fit time by indicating the exact set of data
|
||||||
:class:`quapy.data.base.LabelledCollection` (the split itself).
|
on which the predictions are to be generated.
|
||||||
:param n_jobs: number of parallel workers
|
:param n_jobs: number of parallel workers
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -364,16 +467,20 @@ class PACC(AggregativeProbabilisticQuantifier):
|
||||||
self.val_split = val_split
|
self.val_split = val_split
|
||||||
self.n_jobs = qp._get_njobs(n_jobs)
|
self.n_jobs = qp._get_njobs(n_jobs)
|
||||||
|
|
||||||
def aggregation_fit(self, classif_predictions: LabelledCollection):
|
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||||
"""
|
"""
|
||||||
Estimates the misclassification rates
|
Estimates the misclassification rates
|
||||||
|
|
||||||
:param classif_predictions: classifier predictions with true labels
|
:param classif_predictions: classifier soft predictions with true labels
|
||||||
"""
|
"""
|
||||||
true_labels, posteriors = classif_predictions
|
posteriors, true_labels = classif_predictions.Xy
|
||||||
self.pcc = PCC(self.classifier)
|
self.pcc = PCC(self.classifier)
|
||||||
self.Pte_cond_estim_ = self.getPteCondEstim(self.classifier.classes_, true_labels, posteriors)
|
self.Pte_cond_estim_ = self.getPteCondEstim(self.classifier.classes_, true_labels, posteriors)
|
||||||
|
|
||||||
|
def aggregate(self, classif_posteriors):
|
||||||
|
prevs_estim = self.pcc.aggregate(classif_posteriors)
|
||||||
|
return ACC.solve_adjustment(self.Pte_cond_estim_, prevs_estim)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def getPteCondEstim(cls, classes, y, y_):
|
def getPteCondEstim(cls, classes, y, y_):
|
||||||
# estimate the matrix with entry (i,j) being the estimate of P(yi|yj), that is, the probability that a
|
# estimate the matrix with entry (i,j) being the estimate of P(yi|yj), that is, the probability that a
|
||||||
|
@ -387,15 +494,8 @@ class PACC(AggregativeProbabilisticQuantifier):
|
||||||
|
|
||||||
return confusion.T
|
return confusion.T
|
||||||
|
|
||||||
def aggregate(self, classif_posteriors):
|
|
||||||
prevs_estim = self.pcc.aggregate(classif_posteriors)
|
|
||||||
return ACC.solve_adjustment(self.Pte_cond_estim_, prevs_estim)
|
|
||||||
|
|
||||||
def classify(self, data):
|
class EMQ(AggregativeSoftQuantifier):
|
||||||
return self.pcc.classify(data)
|
|
||||||
|
|
||||||
|
|
||||||
class EMQ(AggregativeProbabilisticQuantifier):
|
|
||||||
"""
|
"""
|
||||||
`Expectation Maximization for Quantification <https://ieeexplore.ieee.org/abstract/document/6789744>`_ (EMQ),
|
`Expectation Maximization for Quantification <https://ieeexplore.ieee.org/abstract/document/6789744>`_ (EMQ),
|
||||||
aka `Saerens-Latinne-Decaestecker` (SLD) algorithm.
|
aka `Saerens-Latinne-Decaestecker` (SLD) algorithm.
|
||||||
|
@ -404,74 +504,30 @@ class EMQ(AggregativeProbabilisticQuantifier):
|
||||||
maximum-likelihood estimation, in a mutually recursive way, until convergence.
|
maximum-likelihood estimation, in a mutually recursive way, until convergence.
|
||||||
|
|
||||||
:param classifier: a sklearn's Estimator that generates a classifier
|
:param classifier: a sklearn's Estimator that generates a classifier
|
||||||
:param exact_train_prev: set to True (default) for using, as the initial observation, the true training prevalence;
|
|
||||||
or set to False for computing the training prevalence as an estimate, akin to PCC, i.e., as the expected
|
|
||||||
value of the posterior probabilities of the training instances as suggested in
|
|
||||||
`Alexandari et al. paper <http://proceedings.mlr.press/v119/alexandari20a.html>`_:
|
|
||||||
:param recalib: a string indicating the method of recalibration. Available choices include "nbvs" (No-Bias Vector
|
|
||||||
Scaling), "bcts" (Bias-Corrected Temperature Scaling), "ts" (Temperature Scaling), and "vs" (Vector Scaling).
|
|
||||||
The default value is None, indicating no recalibration.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
MAX_ITER = 1000
|
MAX_ITER = 1000
|
||||||
EPSILON = 1e-4
|
EPSILON = 1e-4
|
||||||
|
|
||||||
def __init__(self, classifier: BaseEstimator, exact_train_prev=True, recalib=None):
|
def __init__(self, classifier: BaseEstimator):
|
||||||
self.classifier = classifier
|
self.classifier = classifier
|
||||||
self.non_calibrated = classifier
|
|
||||||
self.exact_train_prev = exact_train_prev
|
|
||||||
self.recalib = recalib
|
|
||||||
|
|
||||||
def classifier_fit_predict(self, data: LabelledCollection, fit_classifier=True):
|
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||||
self.classifier, true_labels, posteriors, classes, class_count = cross_generate_predictions(
|
self.train_prevalence = data.prevalence()
|
||||||
data, self.classifier, val_split, probabilistic=True, fit_classifier=fit_classifier, n_jobs=self.n_jobs
|
|
||||||
)
|
|
||||||
|
|
||||||
return (true_labels, posteriors)
|
|
||||||
|
|
||||||
if self.recalib is not None:
|
|
||||||
if self.recalib == 'nbvs':
|
|
||||||
self.classifier = NBVSCalibration(self.non_calibrated)
|
|
||||||
elif self.recalib == 'bcts':
|
|
||||||
self.classifier = BCTSCalibration(self.non_calibrated)
|
|
||||||
elif self.recalib == 'ts':
|
|
||||||
self.classifier = TSCalibration(self.non_calibrated)
|
|
||||||
elif self.recalib == 'vs':
|
|
||||||
self.classifier = VSCalibration(self.non_calibrated)
|
|
||||||
elif self.recalib == 'platt':
|
|
||||||
self.classifier = CalibratedClassifierCV(self.classifier, ensemble=False)
|
|
||||||
else:
|
|
||||||
raise ValueError('invalid param argument for recalibration method; available ones are '
|
|
||||||
'"nbvs", "bcts", "ts", and "vs".')
|
|
||||||
self.recalib = None
|
|
||||||
else:
|
|
||||||
self.classifier = self.non_calibrated
|
|
||||||
self.classifier, _ = _training_helper(self.classifier, data, fit_classifier, ensure_probabilistic=True)
|
|
||||||
if self.exact_train_prev:
|
|
||||||
self.train_prevalence = F.prevalence_from_labels(data.labels, self.classes_)
|
|
||||||
else:
|
|
||||||
self.train_prevalence = qp.model_selection.cross_val_predict(
|
|
||||||
quantifier=PCC(deepcopy(self.classifier)),
|
|
||||||
data=data,
|
|
||||||
nfolds=3,
|
|
||||||
random_state=0
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def aggregation_fit(self, classif_predictions: np.ndarray):
|
|
||||||
"""
|
|
||||||
Nothing to do here!
|
|
||||||
|
|
||||||
:param classif_predictions: this is actually None
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def aggregate(self, classif_posteriors, epsilon=EPSILON):
|
def aggregate(self, classif_posteriors, epsilon=EPSILON):
|
||||||
priors, posteriors = self.EM(self.train_prevalence, classif_posteriors, epsilon)
|
priors, posteriors = self.EM(self.train_prevalence, classif_posteriors, epsilon)
|
||||||
return priors
|
return priors
|
||||||
|
|
||||||
def predict_proba(self, instances, epsilon=EPSILON):
|
def predict_proba(self, instances, epsilon=EPSILON):
|
||||||
classif_posteriors = self.classifier.predict_proba(instances)
|
"""
|
||||||
|
Returns the posterior probabilities updated by the EM algorithm.
|
||||||
|
|
||||||
|
:param instances: np.ndarray of shape `(n_instances, n_dimensions)`
|
||||||
|
:param epsilon: error tolerance
|
||||||
|
:return: np.ndarray of shape `(n_instances, n_classes)`
|
||||||
|
"""
|
||||||
|
classif_posteriors = self.classify(instances)
|
||||||
priors, posteriors = self.EM(self.train_prevalence, classif_posteriors, epsilon)
|
priors, posteriors = self.EM(self.train_prevalence, classif_posteriors, epsilon)
|
||||||
return posteriors
|
return posteriors
|
||||||
|
|
||||||
|
@ -514,7 +570,94 @@ class EMQ(AggregativeProbabilisticQuantifier):
|
||||||
return qs, ps
|
return qs, ps
|
||||||
|
|
||||||
|
|
||||||
class HDy(AggregativeProbabilisticQuantifier, BinaryQuantifier):
|
class EMQrecalib(AggregativeSoftQuantifier, CorrectionbasedAggregativeQuantifier):
    """
    `Expectation Maximization for Quantification <https://ieeexplore.ieee.org/abstract/document/6789744>`_ (EMQ),
    aka `Saerens-Latinne-Decaestecker` (SLD) algorithm, with the heuristics proposed by
    `Alexandari et al. paper <http://proceedings.mlr.press/v119/alexandari20a.html>`_.

    These heuristics consist of using, as the training prevalence, an estimate of it (instead of the true
    training prevalence), and to recalibrate the posterior probabilities of the classifier.

    :param classifier: a sklearn's Estimator that generates a classifier
    :param val_split: specifies the data used for generating classifier predictions. This specification
        can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
        be extracted from the training set; or as an integer (default 5), indicating that the predictions
        are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
        for `k`); or as a collection defining the specific set of data to use for validation.
        Alternatively, this set can be specified at fit time by indicating the exact set of data
        on which the predictions are to be generated.
    :param exact_train_prev: set to True for using, as the initial observation, the true training prevalence;
        or set to False (default) for computing the training prevalence as an estimate of it, i.e., as the
        expected value of the posterior probabilities of the training instances
    :param recalib: a string indicating the method of recalibration.
        Available choices include "nbvs" (No-Bias Vector Scaling), "bcts" (Bias-Corrected Temperature Scaling,
        default), "ts" (Temperature Scaling), and "vs" (Vector Scaling). Use None to skip recalibration.
    :param n_jobs: number of parallel workers
    """

    MAX_ITER = 1000
    EPSILON = 1e-4

    def __init__(self, classifier: BaseEstimator, val_split=5, exact_train_prev=False, recalib='bcts', n_jobs=None):
        self.classifier = classifier
        self.val_split = val_split
        self.exact_train_prev = exact_train_prev
        self.recalib = recalib
        self.n_jobs = n_jobs

    def classify(self, instances):
        """
        Provides the posterior probabilities for the given instances. If the classifier is
        recalibrated, then these posteriors will be recalibrated accordingly.

        :param instances: array-like of shape `(n_instances, n_dimensions,)`
        :return: np.ndarray of shape `(n_instances, n_classes,)` with posterior probabilities
        """
        posteriors = self.classifier.predict_proba(instances)
        # the calibration map only exists once `aggregation_fit` has been run with a recalibration method
        calibrate = getattr(self, 'calibration_function', None)
        if calibrate is not None:
            posteriors = calibrate(posteriors)
        return posteriors

    def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
        """
        Learns the (optional) calibration map and sets the training prevalence used to initialize EM.

        :param classif_predictions: a :class:`quapy.data.base.LabelledCollection` whose `Xy` pair holds the
            classifier posteriors and the corresponding true labels
        :param data: a :class:`quapy.data.base.LabelledCollection` with the training data
        :raises ValueError: if `recalib` is neither None nor one of "nbvs", "bcts", "ts", "vs"
        """
        if self.recalib is None:
            # discard any calibration map left over from a previous fit; otherwise `classify`
            # would keep recalibrating even though recalibration has been switched off
            self.calibration_function = None
        else:
            P, y = classif_predictions.Xy
            # factories are kept lazy so that only the requested calibrator is instantiated
            calibrator_factories = {
                'nbvs': NoBiasVectorScaling,
                'bcts': lambda: TempScaling(bias_positions='all'),
                'ts': TempScaling,
                'vs': VectorScaling,
            }
            try:
                new_calibrator = calibrator_factories[self.recalib]
            except (KeyError, TypeError):
                raise ValueError('invalid param argument for recalibration method; available ones are '
                                 '"nbvs", "bcts", "ts", and "vs".') from None
            calibrator = new_calibrator()
            # the calibrator is fed the posteriors and the one-hot encoding of the labels,
            # and returns the function that `classify` applies to new posteriors
            self.calibration_function = calibrator(P, np.eye(data.n_classes)[y], posterior_supplied=True)

        if self.exact_train_prev:
            self.train_prevalence = F.prevalence_from_labels(data.labels, self.classes_)
        else:
            if self.recalib is not None:
                # recalibrated posteriors of the training instances
                train_posteriors = self.classify(data.X)
            else:
                train_posteriors = classif_predictions.X

            self.train_prevalence = np.mean(train_posteriors, axis=0)

    def aggregate(self, classif_posteriors, epsilon=EPSILON):
        """
        Runs EM on the given posteriors and returns the estimated class priors.

        :param classif_posteriors: posterior probabilities as returned by :meth:`classify`
        :param epsilon: error tolerance passed on to `EMQ.EM`
        :return: the priors returned by `EMQ.EM`
        """
        priors, posteriors = EMQ.EM(self.train_prevalence, classif_posteriors, epsilon)
        return priors

    def predict_proba(self, instances, epsilon=EPSILON):
        """
        Returns the posterior probabilities updated by the EM algorithm.

        :param instances: array-like of shape `(n_instances, n_dimensions,)`
        :param epsilon: error tolerance passed on to `EMQ.EM`
        :return: the posteriors returned by `EMQ.EM`
        """
        classif_posteriors = self.classify(instances)
        priors, posteriors = EMQ.EM(self.train_prevalence, classif_posteriors, epsilon)
        return posteriors
|
||||||
|
|
||||||
|
|
||||||
|
class HDy(AggregativeSoftQuantifier, BinaryQuantifier, CorrectionbasedAggregativeQuantifier):
|
||||||
"""
|
"""
|
||||||
`Hellinger Distance y <https://www.sciencedirect.com/science/article/pii/S0020025512004069>`_ (HDy).
|
`Hellinger Distance y <https://www.sciencedirect.com/science/article/pii/S0020025512004069>`_ (HDy).
|
||||||
HDy is a probabilistic method for training binary quantifiers, that models quantification as the problem of
|
HDy is a probabilistic method for training binary quantifiers, that models quantification as the problem of
|
||||||
|
@ -533,7 +676,7 @@ class HDy(AggregativeProbabilisticQuantifier, BinaryQuantifier):
|
||||||
self.classifier = classifier
|
self.classifier = classifier
|
||||||
self.val_split = val_split
|
self.val_split = val_split
|
||||||
|
|
||||||
def fit(self, data: LabelledCollection, fit_classifier=True, val_split: Union[float, LabelledCollection] = None):
|
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||||
"""
|
"""
|
||||||
Trains a HDy quantifier.
|
Trains a HDy quantifier.
|
||||||
|
|
||||||
|
@ -544,22 +687,23 @@ class HDy(AggregativeProbabilisticQuantifier, BinaryQuantifier):
|
||||||
:class:`quapy.data.base.LabelledCollection` indicating the validation set itself
|
:class:`quapy.data.base.LabelledCollection` indicating the validation set itself
|
||||||
:return: self
|
:return: self
|
||||||
"""
|
"""
|
||||||
if val_split is None:
|
|
||||||
val_split = self.val_split
|
|
||||||
|
|
||||||
self._check_binary(data, self.__class__.__name__)
|
self._check_binary(data, self.__class__.__name__)
|
||||||
self.classifier, validation = _training_helper(
|
P, y = classif_predictions.Xy
|
||||||
self.classifier, data, fit_classifier, ensure_probabilistic=True, val_split=val_split)
|
Px = P[:, 1] # takes only the P(y=+1|x)
|
||||||
Px = self.classify(validation.instances)[:, 1] # takes only the P(y=+1|x)
|
self.Pxy1 = Px[y == self.classifier.classes_[1]]
|
||||||
self.Pxy1 = Px[validation.labels == self.classifier.classes_[1]]
|
self.Pxy0 = Px[y == self.classifier.classes_[0]]
|
||||||
self.Pxy0 = Px[validation.labels == self.classifier.classes_[0]]
|
|
||||||
# pre-compute the histogram for positive and negative examples
|
# pre-compute the histogram for positive and negative examples
|
||||||
self.bins = np.linspace(10, 110, 11, dtype=int) # [10, 20, 30, ..., 100, 110]
|
self.bins = np.linspace(10, 110, 11, dtype=int) # [10, 20, 30, ..., 100, 110]
|
||||||
|
|
||||||
def hist(P, bins):
|
def hist(P, bins):
|
||||||
h = np.histogram(P, bins=bins, range=(0, 1), density=True)[0]
|
h = np.histogram(P, bins=bins, range=(0, 1), density=True)[0]
|
||||||
return h / h.sum()
|
return h / h.sum()
|
||||||
|
|
||||||
self.Pxy1_density = {bins: hist(self.Pxy1, bins) for bins in self.bins}
|
self.Pxy1_density = {bins: hist(self.Pxy1, bins) for bins in self.bins}
|
||||||
self.Pxy0_density = {bins: hist(self.Pxy0, bins) for bins in self.bins}
|
self.Pxy0_density = {bins: hist(self.Pxy0, bins) for bins in self.bins}
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def aggregate(self, classif_posteriors):
|
def aggregate(self, classif_posteriors):
|
||||||
|
@ -583,7 +727,7 @@ class HDy(AggregativeProbabilisticQuantifier, BinaryQuantifier):
|
||||||
# at small steps (modern implementations resort to an optimization procedure,
|
# at small steps (modern implementations resort to an optimization procedure,
|
||||||
# see class DistributionMatching)
|
# see class DistributionMatching)
|
||||||
prev_selected, min_dist = None, None
|
prev_selected, min_dist = None, None
|
||||||
for prev in F.prevalence_linspace(n_prevalences=100, repeats=1, smooth_limits_epsilon=0.0):
|
for prev in F.prevalence_linspace(n_prevalences=101, repeats=1, smooth_limits_epsilon=0.0):
|
||||||
Px_train = prev * Pxy1_density + (1 - prev) * Pxy0_density
|
Px_train = prev * Pxy1_density + (1 - prev) * Pxy0_density
|
||||||
hdy = F.HellingerDistance(Px_train, Px_test)
|
hdy = F.HellingerDistance(Px_train, Px_test)
|
||||||
if prev_selected is None or hdy < min_dist:
|
if prev_selected is None or hdy < min_dist:
|
||||||
|
@ -594,7 +738,7 @@ class HDy(AggregativeProbabilisticQuantifier, BinaryQuantifier):
|
||||||
return np.asarray([1 - class1_prev, class1_prev])
|
return np.asarray([1 - class1_prev, class1_prev])
|
||||||
|
|
||||||
|
|
||||||
class DyS(AggregativeProbabilisticQuantifier, BinaryQuantifier):
|
class DyS(AggregativeSoftQuantifier, BinaryQuantifier):
|
||||||
"""
|
"""
|
||||||
`DyS framework <https://ojs.aaai.org/index.php/AAAI/article/view/4376>`_ (DyS).
|
`DyS framework <https://ojs.aaai.org/index.php/AAAI/article/view/4376>`_ (DyS).
|
||||||
DyS is a generalization of HDy method, using a Ternary Search in order to find the prevalence that
|
DyS is a generalization of HDy method, using a Ternary Search in order to find the prevalence that
|
||||||
|
@ -661,7 +805,7 @@ class DyS(AggregativeProbabilisticQuantifier, BinaryQuantifier):
|
||||||
return np.asarray([1 - class1_prev, class1_prev])
|
return np.asarray([1 - class1_prev, class1_prev])
|
||||||
|
|
||||||
|
|
||||||
class SMM(AggregativeProbabilisticQuantifier, BinaryQuantifier):
|
class SMM(AggregativeSoftQuantifier, BinaryQuantifier):
|
||||||
"""
|
"""
|
||||||
`SMM method <https://ieeexplore.ieee.org/document/9260028>`_ (SMM).
|
`SMM method <https://ieeexplore.ieee.org/document/9260028>`_ (SMM).
|
||||||
SMM is a simplification of matching distribution methods where the representation of the examples
|
SMM is a simplification of matching distribution methods where the representation of the examples
|
||||||
|
@ -700,7 +844,7 @@ class SMM(AggregativeProbabilisticQuantifier, BinaryQuantifier):
|
||||||
return np.asarray([1 - class1_prev, class1_prev])
|
return np.asarray([1 - class1_prev, class1_prev])
|
||||||
|
|
||||||
|
|
||||||
class DMy(AggregativeProbabilisticQuantifier):
|
class DMy(AggregativeSoftQuantifier, CorrectionbasedAggregativeQuantifier):
|
||||||
"""
|
"""
|
||||||
Generic Distribution Matching quantifier for binary or multiclass quantification based on the space of posterior
|
Generic Distribution Matching quantifier for binary or multiclass quantification based on the space of posterior
|
||||||
probabilities. This implementation takes the number of bins, the divergence, and the possibility to work on CDF
|
probabilities. This implementation takes the number of bins, the divergence, and the possibility to work on CDF
|
||||||
|
@ -736,7 +880,7 @@ class DMy(AggregativeProbabilisticQuantifier):
|
||||||
from quapy.method.meta import MedianEstimator
|
from quapy.method.meta import MedianEstimator
|
||||||
|
|
||||||
hdy = DMy(classifier=classifier, val_split=val_split, search='linear_search', divergence='HD')
|
hdy = DMy(classifier=classifier, val_split=val_split, search='linear_search', divergence='HD')
|
||||||
hdy = MedianEstimator(hdy, param_grid={'nbins': np.linspace(10, 110, 11).astype(int)}, n_jobs=n_jobs)
|
hdy = AggregativeMedianEstimator(hdy, param_grid={'nbins': np.linspace(10, 110, 11).astype(int)}, n_jobs=n_jobs)
|
||||||
return hdy
|
return hdy
|
||||||
|
|
||||||
def __get_distributions(self, posteriors):
|
def __get_distributions(self, posteriors):
|
||||||
|
@ -755,7 +899,7 @@ class DMy(AggregativeProbabilisticQuantifier):
|
||||||
distributions = np.cumsum(distributions, axis=1)
|
distributions = np.cumsum(distributions, axis=1)
|
||||||
return distributions
|
return distributions
|
||||||
|
|
||||||
def classifier_fit_predict(self, data: LabelledCollection, fit_classifier=True, val_split: Union[float, LabelledCollection] = None):
|
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||||
"""
|
"""
|
||||||
Trains the classifier (if requested) and generates the validation distributions out of the training data.
|
Trains the classifier (if requested) and generates the validation distributions out of the training data.
|
||||||
The validation distributions have shape `(n, ch, nbins)`, with `n` the number of classes, `ch` the number of
|
The validation distributions have shape `(n, ch, nbins)`, with `n` the number of classes, `ch` the number of
|
||||||
|
@ -771,21 +915,13 @@ class DMy(AggregativeProbabilisticQuantifier):
|
||||||
indicating the validation set itself, or an int indicating the number k of folds to be used in kFCV
|
indicating the validation set itself, or an int indicating the number k of folds to be used in kFCV
|
||||||
to estimate the parameters
|
to estimate the parameters
|
||||||
"""
|
"""
|
||||||
if val_split is None:
|
posteriors, true_labels = classif_predictions.Xy
|
||||||
val_split = self.val_split
|
|
||||||
|
|
||||||
self.classifier, true_labels, posteriors, classes, class_count = cross_generate_predictions(
|
|
||||||
data, self.classifier, val_split, probabilistic=True, fit_classifier=fit_classifier, n_jobs=self.n_jobs
|
|
||||||
)
|
|
||||||
|
|
||||||
return (true_labels, posteriors)
|
|
||||||
|
|
||||||
def aggregation_fit(self, classif_predictions):
|
|
||||||
true_labels, posteriors = classif_predictions
|
|
||||||
n_classes = len(self.classifier.classes_)
|
n_classes = len(self.classifier.classes_)
|
||||||
|
|
||||||
self.validation_distribution = np.asarray(
|
self.validation_distribution = qp.util.parallel(
|
||||||
[self.__get_distributions(posteriors[true_labels == cat]) for cat in range(n_classes)]
|
func=self.__get_distributions,
|
||||||
|
args=[posteriors[true_labels==cat] for cat in range(n_classes)],
|
||||||
|
n_jobs=self.n_jobs
|
||||||
)
|
)
|
||||||
|
|
||||||
def aggregate(self, posteriors: np.ndarray):
|
def aggregate(self, posteriors: np.ndarray):
|
||||||
|
@ -1252,7 +1388,7 @@ class OneVsAllAggregative(OneVsAllGeneric, AggregativeQuantifier):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
classif_predictions = self._parallel(self._delayed_binary_classification, instances)
|
classif_predictions = self._parallel(self._delayed_binary_classification, instances)
|
||||||
if isinstance(self.binary_quantifier, AggregativeProbabilisticQuantifier):
|
if isinstance(self.binary_quantifier, AggregativeSoftQuantifier):
|
||||||
return np.swapaxes(classif_predictions, 0, 1)
|
return np.swapaxes(classif_predictions, 0, 1)
|
||||||
else:
|
else:
|
||||||
return classif_predictions.T
|
return classif_predictions.T
|
||||||
|
@ -1269,6 +1405,130 @@ class OneVsAllAggregative(OneVsAllGeneric, AggregativeQuantifier):
|
||||||
return self.dict_binary_quantifiers[c].aggregate(classif_predictions[:, c])[1]
|
return self.dict_binary_quantifiers[c].aggregate(classif_predictions[:, c])[1]
|
||||||
|
|
||||||
|
|
||||||
|
class AggregativeMedianEstimator(BinaryQuantifier):
    """
    This method is a meta-quantifier that returns, as the estimated class prevalence values, the median of the
    estimation returned by differently (hyper)parameterized base quantifiers.
    The median of unit-vectors is only guaranteed to be a unit-vector for n=2 dimensions,
    i.e., in cases of binary quantification.

    :param base_quantifier: the base, binary quantifier
    :param random_state: a seed to be set before fitting any base quantifier (default None)
    :param param_grid: the grid of parameters towards which the median will be computed
    :param n_jobs: number of parallel workers
    """
    def __init__(self, base_quantifier: AggregativeQuantifier, param_grid: dict, random_state=None, n_jobs=None):
        self.base_quantifier = base_quantifier
        self.param_grid = param_grid
        self.random_state = random_state
        self.n_jobs = qp._get_njobs(n_jobs)

    def get_params(self, deep=True):
        """Returns the hyper-parameters of the base quantifier."""
        return self.base_quantifier.get_params(deep)

    def set_params(self, **params):
        """
        Sets the hyper-parameters of the base quantifier.

        :param params: hyper-parameters forwarded to `base_quantifier.set_params`
        :return: self
        """
        self.base_quantifier.set_params(**params)
        return self

    def _delayed_fit(self, args):
        """Fits a copy of the base quantifier; `args` is a pair `(params, training)`."""
        with qp.util.temp_seed(self.random_state):
            params, training = args
            model = deepcopy(self.base_quantifier)
            model.set_params(**params)
            model.fit(training)
            return model

    def _delayed_fit_classifier(self, args):
        """
        Fits the classifier of a copy of the base quantifier and generates its predictions;
        `args` is a triple `(cls_params, training, kwargs)`.

        :return: a pair `(model, predictions)`
        """
        with qp.util.temp_seed(self.random_state):
            cls_params, training, kwargs = args
            model = deepcopy(self.base_quantifier)
            model.set_params(**cls_params)
            predictions = model.classifier_fit_predict(training, **kwargs)
            return (model, predictions)

    def _delayed_fit_aggregation(self, args):
        """
        Fits the aggregation function of a copy of an already-classifier-fitted model;
        `args` has the form `(((model, predictions), q_params), training)`.

        :return: the copy of the model with its aggregation function fitted
        """
        with qp.util.temp_seed(self.random_state):
            ((model, predictions), q_params), training = args
            # copy, so that the same (model, predictions) pair can be reused for other configurations
            model = deepcopy(model)
            model.set_params(**q_params)
            model.aggregation_fit(predictions, training)
            return model

    def fit(self, training: LabelledCollection, **kwargs):
        """
        Fits one model per configuration in the grid.

        For aggregative base quantifiers, the grid is partitioned into classifier-specific and
        quantifier-specific configurations, so that each classifier is trained only once and then
        reused for every quantifier-specific configuration.

        :param training: the (binary) training collection
        :param kwargs: additional arguments forwarded to `classifier_fit_predict`
        :return: self
        """
        import itertools

        self._check_binary(training, self.__class__.__name__)

        if isinstance(self.base_quantifier, AggregativeQuantifier):
            cls_configs, q_configs = qp.model_selection.group_params(self.param_grid)

            if len(cls_configs) > 1:
                models_preds = qp.util.parallel(
                    self._delayed_fit_classifier,
                    ((params, training, kwargs) for params in cls_configs),
                    seed=qp.environ.get('_R_SEED', None),
                    n_jobs=self.n_jobs,
                    asarray=False
                )
            else:
                # one single classifier configuration: fit the base quantifier in place
                model = self.base_quantifier
                model.set_params(**cls_configs[0])
                predictions = model.classifier_fit_predict(training, **kwargs)
                models_preds = [(model, predictions)]

            self.training = training

            # NOTE: the aggregation functions are fitted sequentially; `_delayed_fit_aggregation`
            # implements the same step as a job for `qp.util.parallel`
            self.models = []
            for (model, predictions), q_params in itertools.product(models_preds, q_configs):
                model = deepcopy(model)
                model.set_params(**q_params)
                model.aggregation_fit(predictions, training)
                self.models.append(model)
        else:
            configs = qp.model_selection.expand_grid(self.param_grid)
            self.models = qp.util.parallel(
                self._delayed_fit,
                ((params, training) for params in configs),
                seed=qp.environ.get('_R_SEED', None),
                n_jobs=self.n_jobs,
                asarray=False
            )
        return self

    def _delayed_predict(self, args):
        """Returns the prevalence estimated by one model; `args` is a pair `(model, instances)`."""
        model, instances = args
        return model.quantify(instances)

    def quantify(self, instances):
        """
        Returns the element-wise median of the prevalence estimates of all fitted models.

        :param instances: the instances to quantify
        :return: np.ndarray with the median of the estimates computed along axis 0
        """
        prev_preds = qp.util.parallel(
            self._delayed_predict,
            ((model, instances) for model in self.models),
            seed=qp.environ.get('_R_SEED', None),
            n_jobs=self.n_jobs,
            asarray=False
        )
        prev_preds = np.asarray(prev_preds)
        return np.median(prev_preds, axis=0)
|
||||||
|
|
||||||
#---------------------------------------------------------------
|
#---------------------------------------------------------------
|
||||||
# aliases
|
# aliases
|
||||||
#---------------------------------------------------------------
|
#---------------------------------------------------------------
|
||||||
|
|
|
@ -12,7 +12,7 @@ from quapy import functional as F
|
||||||
from quapy.data import LabelledCollection
|
from quapy.data import LabelledCollection
|
||||||
from quapy.model_selection import GridSearchQ
|
from quapy.model_selection import GridSearchQ
|
||||||
from quapy.method.base import BaseQuantifier, BinaryQuantifier
|
from quapy.method.base import BaseQuantifier, BinaryQuantifier
|
||||||
from quapy.method.aggregative import CC, ACC, PACC, HDy, EMQ
|
from quapy.method.aggregative import CC, ACC, PACC, HDy, EMQ, AggregativeQuantifier
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from . import neural
|
from . import neural
|
||||||
|
@ -26,6 +26,65 @@ else:
|
||||||
QuaNet = "QuaNet is not available due to missing torch package"
|
QuaNet = "QuaNet is not available due to missing torch package"
|
||||||
|
|
||||||
|
|
||||||
|
class MedianEstimator2(BinaryQuantifier):
    """
    This method is a meta-quantifier that returns, as the estimated class prevalence values, the median of the
    estimation returned by differently (hyper)parameterized base quantifiers.
    The median of unit-vectors is only guaranteed to be a unit-vector for n=2 dimensions,
    i.e., in cases of binary quantification.

    :param base_quantifier: the base, binary quantifier
    :param random_state: a seed to be set before fitting any base quantifier (default None)
    :param param_grid: the grid of parameters towards which the median will be computed
    :param n_jobs: number of parallel workers
    """
    def __init__(self, base_quantifier: BinaryQuantifier, param_grid: dict, random_state=None, n_jobs=None):
        self.base_quantifier = base_quantifier
        self.param_grid = param_grid
        self.random_state = random_state
        self.n_jobs = qp._get_njobs(n_jobs)

    def get_params(self, deep=True):
        """Returns the hyper-parameters of the base quantifier."""
        return self.base_quantifier.get_params(deep)

    def set_params(self, **params):
        """
        Sets the hyper-parameters of the base quantifier.

        :param params: hyper-parameters forwarded to `base_quantifier.set_params`
        :return: self
        """
        self.base_quantifier.set_params(**params)
        return self

    def _delayed_fit(self, args):
        """Fits a copy of the base quantifier; `args` is a pair `(params, training)`."""
        with qp.util.temp_seed(self.random_state):
            params, training = args
            model = deepcopy(self.base_quantifier)
            model.set_params(**params)
            model.fit(training)
            return model

    def fit(self, training: LabelledCollection):
        """
        Fits one copy of the base quantifier for each configuration in the grid.

        :param training: the (binary) training collection
        :return: self
        """
        self._check_binary(training, self.__class__.__name__)

        configs = qp.model_selection.expand_grid(self.param_grid)
        self.models = qp.util.parallel(
            self._delayed_fit,
            ((params, training) for params in configs),
            seed=qp.environ.get('_R_SEED', None),
            n_jobs=self.n_jobs
        )
        return self

    def _delayed_predict(self, args):
        """Returns the prevalence estimated by one model; `args` is a pair `(model, instances)`."""
        model, instances = args
        return model.quantify(instances)

    def quantify(self, instances):
        """
        Returns the element-wise median of the prevalence estimates of all fitted models.

        :param instances: the instances to quantify
        :return: np.ndarray with the median of the estimates computed along axis 0
        """
        prev_preds = qp.util.parallel(
            self._delayed_predict,
            ((model, instances) for model in self.models),
            seed=qp.environ.get('_R_SEED', None),
            n_jobs=self.n_jobs
        )
        prev_preds = np.asarray(prev_preds)
        return np.median(prev_preds, axis=0)
|
||||||
|
|
||||||
|
|
||||||
class MedianEstimator(BinaryQuantifier):
|
class MedianEstimator(BinaryQuantifier):
|
||||||
"""
|
"""
|
||||||
This method is a meta-quantifier that returns, as the estimated class prevalence values, the median of the
|
This method is a meta-quantifier that returns, as the estimated class prevalence values, the median of the
|
||||||
|
@ -58,17 +117,64 @@ class MedianEstimator(BinaryQuantifier):
|
||||||
model.fit(training)
|
model.fit(training)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def _delayed_fit_classifier(self, args):
|
||||||
|
with qp.util.temp_seed(self.random_state):
|
||||||
|
print('enter job')
|
||||||
|
cls_params, training = args
|
||||||
|
model = deepcopy(self.base_quantifier)
|
||||||
|
model.set_params(**cls_params)
|
||||||
|
predictions = model.classifier_fit_predict(training, predict_on=model.val_split)
|
||||||
|
print('exit job')
|
||||||
|
return (model, predictions)
|
||||||
|
|
||||||
|
def _delayed_fit_aggregation(self, args):
|
||||||
|
with qp.util.temp_seed(self.random_state):
|
||||||
|
print('\tenter job')
|
||||||
|
((model, predictions), q_params), training = args
|
||||||
|
model = deepcopy(model)
|
||||||
|
model.set_params(**q_params)
|
||||||
|
model.aggregation_fit(predictions, training)
|
||||||
|
print('\texit job')
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def fit(self, training: LabelledCollection):
|
def fit(self, training: LabelledCollection):
|
||||||
self._check_binary(training, self.__class__.__name__)
|
self._check_binary(training, self.__class__.__name__)
|
||||||
params_keys = list(self.param_grid.keys())
|
|
||||||
params_values = list(self.param_grid.values())
|
if isinstance(self.base_quantifier, AggregativeQuantifier):
|
||||||
hyper = [dict({k: val[i] for i, k in enumerate(params_keys)}) for val in itertools.product(*params_values)]
|
cls_configs, q_configs = qp.model_selection.group_params(self.param_grid)
|
||||||
self.models = qp.util.parallel(
|
|
||||||
self._delayed_fit,
|
if len(cls_configs) > 1:
|
||||||
((params, training) for params in hyper),
|
models_preds = qp.util.parallel(
|
||||||
seed=qp.environ.get('_R_SEED', None),
|
self._delayed_fit_classifier,
|
||||||
n_jobs=self.n_jobs
|
((params, training) for params in cls_configs),
|
||||||
)
|
seed=qp.environ.get('_R_SEED', None),
|
||||||
|
n_jobs=self.n_jobs,
|
||||||
|
asarray=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print('only 1')
|
||||||
|
model = self.base_quantifier
|
||||||
|
model.set_params(**cls_configs[0])
|
||||||
|
predictions = model.classifier_fit_predict(training, predict_on=model.val_split)
|
||||||
|
models_preds = [(model, predictions)]
|
||||||
|
|
||||||
|
self.models = qp.util.parallel(
|
||||||
|
self._delayed_fit_aggregation,
|
||||||
|
((setup, training) for setup in itertools.product(models_preds, q_configs)),
|
||||||
|
seed=qp.environ.get('_R_SEED', None),
|
||||||
|
n_jobs=self.n_jobs,
|
||||||
|
asarray=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
configs = qp.model_selection.expand_grid(self.param_grid)
|
||||||
|
self.models = qp.util.parallel(
|
||||||
|
self._delayed_fit,
|
||||||
|
((params, training) for params in configs),
|
||||||
|
seed=qp.environ.get('_R_SEED', None),
|
||||||
|
n_jobs=self.n_jobs,
|
||||||
|
asarray=False
|
||||||
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _delayed_predict(self, args):
|
def _delayed_predict(self, args):
|
||||||
|
@ -80,13 +186,13 @@ class MedianEstimator(BinaryQuantifier):
|
||||||
self._delayed_predict,
|
self._delayed_predict,
|
||||||
((model, instances) for model in self.models),
|
((model, instances) for model in self.models),
|
||||||
seed=qp.environ.get('_R_SEED', None),
|
seed=qp.environ.get('_R_SEED', None),
|
||||||
n_jobs=self.n_jobs
|
n_jobs=self.n_jobs,
|
||||||
|
asarray=False
|
||||||
)
|
)
|
||||||
prev_preds = np.asarray(prev_preds)
|
prev_preds = np.asarray(prev_preds)
|
||||||
return np.median(prev_preds, axis=0)
|
return np.median(prev_preds, axis=0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Ensemble(BaseQuantifier):
|
class Ensemble(BaseQuantifier):
|
||||||
VALID_POLICIES = {'ave', 'ptr', 'ds'} | qp.error.QUANTIFICATION_ERROR_NAMES
|
VALID_POLICIES = {'ave', 'ptr', 'ds'} | qp.error.QUANTIFICATION_ERROR_NAMES
|
||||||
|
|
||||||
|
|
|
@ -194,7 +194,7 @@ class QuaNetTrainer(BaseQuantifier):
|
||||||
label_predictions = np.argmax(posteriors, axis=-1)
|
label_predictions = np.argmax(posteriors, axis=-1)
|
||||||
prevs_estim = []
|
prevs_estim = []
|
||||||
for quantifier in self.quantifiers.values():
|
for quantifier in self.quantifiers.values():
|
||||||
predictions = posteriors if isinstance(quantifier, AggregativeProbabilisticQuantifier) else label_predictions
|
predictions = posteriors if isinstance(quantifier, AggregativeSoftQuantifier) else label_predictions
|
||||||
prevs_estim.extend(quantifier.aggregate(predictions))
|
prevs_estim.extend(quantifier.aggregate(predictions))
|
||||||
|
|
||||||
# there is no real need for adding static estims like the TPR or FPR from training since those are constant
|
# there is no real need for adding static estims like the TPR or FPR from training since those are constant
|
||||||
|
|
|
@ -76,8 +76,6 @@ class GridSearchQ(BaseQuantifier):
|
||||||
:param training: the training set on which to optimize the hyperparameters
|
:param training: the training set on which to optimize the hyperparameters
|
||||||
:return: self
|
:return: self
|
||||||
"""
|
"""
|
||||||
params_keys = list(self.param_grid.keys())
|
|
||||||
params_values = list(self.param_grid.values())
|
|
||||||
|
|
||||||
protocol = self.protocol
|
protocol = self.protocol
|
||||||
|
|
||||||
|
@ -86,12 +84,13 @@ class GridSearchQ(BaseQuantifier):
|
||||||
|
|
||||||
tinit = time()
|
tinit = time()
|
||||||
|
|
||||||
hyper = [dict({k: val[i] for i, k in enumerate(params_keys)}) for val in itertools.product(*params_values)]
|
configs = expand_grid(self.param_grid)
|
||||||
|
|
||||||
self._sout(f'starting model selection with {self.n_jobs =}')
|
self._sout(f'starting model selection with {self.n_jobs =}')
|
||||||
#pass a seed to parallel so it is set in clild processes
|
#pass a seed to parallel so it is set in child processes
|
||||||
scores = qp.util.parallel(
|
scores = qp.util.parallel(
|
||||||
self._delayed_eval,
|
self._delayed_eval,
|
||||||
((params, training) for params in hyper),
|
((params, training) for params in configs),
|
||||||
seed=qp.environ.get('_R_SEED', None),
|
seed=qp.environ.get('_R_SEED', None),
|
||||||
n_jobs=self.n_jobs
|
n_jobs=self.n_jobs
|
||||||
)
|
)
|
||||||
|
@ -204,8 +203,6 @@ class GridSearchQ(BaseQuantifier):
|
||||||
raise ValueError('best_model called before fit')
|
raise ValueError('best_model called before fit')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0):
|
def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0):
|
||||||
"""
|
"""
|
||||||
Akin to `scikit-learn's cross_val_predict <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html>`_
|
Akin to `scikit-learn's cross_val_predict <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html>`_
|
||||||
|
@ -229,3 +226,43 @@ def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfol
|
||||||
return total_prev
|
return total_prev
|
||||||
|
|
||||||
|
|
||||||
|
def expand_grid(param_grid: dict):
|
||||||
|
"""
|
||||||
|
Expands a param_grid dictionary as a list of configurations.
|
||||||
|
Example:
|
||||||
|
|
||||||
|
>>> combinations = expand_grid({'A': [1, 10, 100], 'B': [True, False]})
|
||||||
|
>>> print(combinations)
|
||||||
|
[{'A': 1, 'B': True}, {'A': 1, 'B': False}, {'A': 10, 'B': True}, {'A': 10, 'B': False}, {'A': 100, 'B': True}, {'A': 100, 'B': False}]
|
||||||
|
|
||||||
|
:param param_grid: dictionary with keys representing hyper-parameter names, and values representing the range
|
||||||
|
to explore for that hyper-parameter
|
||||||
|
:return: a list of configurations, i.e., combinations of hyper-parameter assignments in the grid.
|
||||||
|
"""
|
||||||
|
params_keys = list(param_grid.keys())
|
||||||
|
params_values = list(param_grid.values())
|
||||||
|
configs = [{k: combs[i] for i, k in enumerate(params_keys)} for combs in itertools.product(*params_values)]
|
||||||
|
return configs
|
||||||
|
|
||||||
|
|
||||||
|
def group_params(param_grid: dict):
|
||||||
|
"""
|
||||||
|
Partitions a param_grid dictionary as two lists of configurations, one for the classifier-specific
|
||||||
|
hyper-parameters, and another for the quantifier-specific hyper-parameters
|
||||||
|
|
||||||
|
:param param_grid: dictionary with keys representing hyper-parameter names, and values representing the range
|
||||||
|
to explore for that hyper-parameter
|
||||||
|
:return: two expanded grids of configurations, one for the classifier, another for the quantifier
|
||||||
|
"""
|
||||||
|
classifier_params, quantifier_params = {}, {}
|
||||||
|
for key, values in param_grid.items():
|
||||||
|
if key.startswith('classifier__') or key == 'val_split':
|
||||||
|
classifier_params[key] = values
|
||||||
|
else:
|
||||||
|
quantifier_params[key] = values
|
||||||
|
|
||||||
|
classifier_configs = expand_grid(classifier_params)
|
||||||
|
quantifier_configs = expand_grid(quantifier_params)
|
||||||
|
|
||||||
|
return classifier_configs, quantifier_configs
|
||||||
|
|
||||||
|
|
|
@ -22,9 +22,9 @@ class HierarchyTestCase(unittest.TestCase):
|
||||||
def test_probabilistic(self):
|
def test_probabilistic(self):
|
||||||
lr = LogisticRegression()
|
lr = LogisticRegression()
|
||||||
for m in [CC(lr), ACC(lr)]:
|
for m in [CC(lr), ACC(lr)]:
|
||||||
self.assertEqual(isinstance(m, AggregativeProbabilisticQuantifier), False)
|
self.assertEqual(isinstance(m, AggregativeSoftQuantifier), False)
|
||||||
for m in [PCC(lr), PACC(lr)]:
|
for m in [PCC(lr), PACC(lr)]:
|
||||||
self.assertEqual(isinstance(m, AggregativeProbabilisticQuantifier), True)
|
self.assertEqual(isinstance(m, AggregativeSoftQuantifier), True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -38,7 +38,7 @@ def map_parallel(func, args, n_jobs):
|
||||||
return list(itertools.chain.from_iterable(results))
|
return list(itertools.chain.from_iterable(results))
|
||||||
|
|
||||||
|
|
||||||
def parallel(func, args, n_jobs, seed=None):
|
def parallel(func, args, n_jobs, seed=None, asarray=True):
|
||||||
"""
|
"""
|
||||||
A wrapper of multiprocessing:
|
A wrapper of multiprocessing:
|
||||||
|
|
||||||
|
@ -58,9 +58,12 @@ def parallel(func, args, n_jobs, seed=None):
|
||||||
stack.enter_context(qp.util.temp_seed(seed))
|
stack.enter_context(qp.util.temp_seed(seed))
|
||||||
return func(*args)
|
return func(*args)
|
||||||
|
|
||||||
return Parallel(n_jobs=n_jobs)(
|
out = Parallel(n_jobs=n_jobs)(
|
||||||
delayed(func_dec)(qp.environ, None if seed is None else seed+i, args_i) for i, args_i in enumerate(args)
|
delayed(func_dec)(qp.environ, None if seed is None else seed+i, args_i) for i, args_i in enumerate(args)
|
||||||
)
|
)
|
||||||
|
if asarray:
|
||||||
|
out = np.asarray(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
|
|
Loading…
Reference in New Issue