added prior

This commit is contained in:
Alejandro Moreo Fernandez 2026-01-11 19:00:13 +01:00
parent ae9503a43b
commit c6fb46cf70
2 changed files with 32 additions and 12 deletions

View File

@ -12,7 +12,8 @@ from quapy.method.aggregative import AggregativeSoftQuantifier
from tqdm import tqdm from tqdm import tqdm
import quapy.functional as F import quapy.functional as F
import emcee import emcee
from collections.abc import Iterable
from numbers import Number
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
@ -58,6 +59,7 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
step_size=0.05, step_size=0.05,
temperature=1., temperature=1.,
engine='numpyro', engine='numpyro',
prior='uniform',
verbose: bool = False): verbose: bool = False):
if num_warmup <= 0: if num_warmup <= 0:
@ -66,6 +68,8 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
raise ValueError(f'parameter {num_samples=} must be a positive integer') raise ValueError(f'parameter {num_samples=} must be a positive integer')
assert explore in ['simplex', 'clr', 'ilr'], \ assert explore in ['simplex', 'clr', 'ilr'], \
f'unexpected value for param {explore=}; valid ones are "simplex", "clr", and "ilr"' f'unexpected value for param {explore=}; valid ones are "simplex", "clr", and "ilr"'
assert prior == 'uniform' or (isinstance(prior, Iterable) and all(isinstance(v, Number) for v in prior)), \
f'wrong type for {prior=}; expected "uniform" or an array-like of real values'
# assert temperature>0., f'temperature must be >0' # assert temperature>0., f'temperature must be >0'
assert engine in ['rw-mh', 'emcee', 'numpyro'] assert engine in ['rw-mh', 'emcee', 'numpyro']
@ -81,6 +85,7 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
self.step_size = step_size self.step_size = step_size
self.temperature = temperature self.temperature = temperature
self.engine = engine self.engine = engine
self.prior = prior
self.verbose = verbose self.verbose = verbose
def aggregation_fit(self, classif_predictions, labels): def aggregation_fit(self, classif_predictions, labels):
@ -89,8 +94,12 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
def aggregate(self, classif_predictions): def aggregate(self, classif_predictions):
if self.engine == 'rw-mh': if self.engine == 'rw-mh':
if self.prior != 'uniform':
raise RuntimeError('prior is not yet implemented in rw-mh')
self.prevalence_samples = self._bayesian_kde(classif_predictions, init=None, verbose=self.verbose) self.prevalence_samples = self._bayesian_kde(classif_predictions, init=None, verbose=self.verbose)
elif self.engine == 'emcee': elif self.engine == 'emcee':
if self.prior != 'uniform':
raise RuntimeError('prior is not yet implemented in emcee')
self.prevalence_samples = self._bayesian_emcee(classif_predictions) self.prevalence_samples = self._bayesian_emcee(classif_predictions)
elif self.engine == 'numpyro': elif self.engine == 'numpyro':
self.prevalence_samples = self._bayesian_numpyro(classif_predictions) self.prevalence_samples = self._bayesian_numpyro(classif_predictions)
@ -237,6 +246,11 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
# move to jax # move to jax
test_densities = jnp.array(test_densities) test_densities = jnp.array(test_densities)
n_classes = X_probs.shape[-1]
if self.prior == 'uniform':
alpha = [1.]*n_classes
else:
alpha = self.prior
kernel = NUTS(self._numpyro_model) kernel = NUTS(self._numpyro_model)
mcmc = MCMC( mcmc = MCMC(
@ -248,7 +262,7 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
) )
rng_key = jax.random.PRNGKey(self.mcmc_seed) rng_key = jax.random.PRNGKey(self.mcmc_seed)
mcmc.run(rng_key, test_densities) mcmc.run(rng_key, test_densities=test_densities, alpha=alpha)
samples_z = mcmc.get_samples()["z"] samples_z = mcmc.get_samples()["z"]
@ -258,26 +272,33 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
return samples_prev return samples_prev
def _numpyro_model(self, test_densities): def _numpyro_model(self, test_densities, alpha):
""" """
test_densities: shape (C, N) test_densities: shape (n_classes, n_instances,)
""" """
C = test_densities.shape[0] n_classes = test_densities.shape[0]
ilr = ILRtransformation(jax_mode=True) ilr = ILRtransformation(jax_mode=True)
# sample in unconstrained R^{C-1} # sample in unconstrained R^(n_classes-1)
z = numpyro.sample( z = numpyro.sample(
"z", "z",
dist.Normal(0.0, 1.0).expand([C - 1]) dist.Normal(0.0, 1.0).expand([n_classes - 1])
) )
prev = ilr.inverse(z) # simplex, shape (C,) prev = ilr.inverse(z) # simplex, shape (n_classes,)
# prior
if alpha is not None:
alpha = jnp.asarray(alpha)
numpyro.factor(
'log_prior', dist.Dirichlet(alpha.log_prob(prev))
)
# if alpha is None, then this corresponds to a weak logistic-normal prior
# likelihood # likelihood
likelihoods = jnp.dot(prev, test_densities) likelihoods = jnp.dot(prev, test_densities)
numpyro.factor( numpyro.factor(
"loglik", "loglik", (1.0 / self.temperature) * jnp.sum(jnp.log(likelihoods + 1e-10))
(1.0 / self.temperature) * jnp.sum(jnp.log(likelihoods + 1e-10))
) )
@ -285,8 +306,6 @@ def in_simplex(x):
return np.all(x >= 0) and np.isclose(x.sum(), 1) return np.all(x >= 0) and np.isclose(x.sum(), 1)
class ILRtransformation(F.CompositionalTransformation): class ILRtransformation(F.CompositionalTransformation):
def __init__(self, jax_mode=False): def __init__(self, jax_mode=False):
self.jax_mode = jax_mode self.jax_mode = jax_mode

View File

@ -85,6 +85,7 @@ def methods():
yield f'BaKDE-numpyro-T*', KDEyML(LR()), kdey_hyper, lambda hyper: BayesianKDEy(mcmc_seed=0, engine='numpyro', temperature=None, **hyper), multiclass_method yield f'BaKDE-numpyro-T*', KDEyML(LR()), kdey_hyper, lambda hyper: BayesianKDEy(mcmc_seed=0, engine='numpyro', temperature=None, **hyper), multiclass_method
yield f'BaKDE-Ait-numpyro', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', **hyper), multiclass_method yield f'BaKDE-Ait-numpyro', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', **hyper), multiclass_method
yield f'BaKDE-Ait-numpyro-T*', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', temperature=None, **hyper), multiclass_method yield f'BaKDE-Ait-numpyro-T*', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', temperature=None, **hyper), multiclass_method
yield f'BaKDE-Ait-numpyro-T*-U', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', temperature=None, prior='uniform', **hyper), multiclass_method
yield f'BaKDE-Ait-numpyro-T*ILR', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', temperature=None, region='ellipse-ilr', **hyper), multiclass_method yield f'BaKDE-Ait-numpyro-T*ILR', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', temperature=None, region='ellipse-ilr', **hyper), multiclass_method
yield f'BaKDE-numpyro-T10', KDEyML(LR()), kdey_hyper, lambda hyper: BayesianKDEy(mcmc_seed=0, engine='numpyro', temperature=10., **hyper), multiclass_method yield f'BaKDE-numpyro-T10', KDEyML(LR()), kdey_hyper, lambda hyper: BayesianKDEy(mcmc_seed=0, engine='numpyro', temperature=10., **hyper), multiclass_method
yield f'BaKDE-numpyro*CLR', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', **hyper), multiclass_method yield f'BaKDE-numpyro*CLR', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', **hyper), multiclass_method