From c6fb46cf70bf1ec6e42cfb2a42f35e2210bc83bc Mon Sep 17 00:00:00 2001 From: Alejandro Moreo Date: Sun, 11 Jan 2026 19:00:13 +0100 Subject: [PATCH] added prior --- BayesianKDEy/_bayeisan_kdey.py | 43 +++++++++++++++++++++++--------- BayesianKDEy/full_experiments.py | 1 + 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/BayesianKDEy/_bayeisan_kdey.py b/BayesianKDEy/_bayeisan_kdey.py index 9a4f171..197deaf 100644 --- a/BayesianKDEy/_bayeisan_kdey.py +++ b/BayesianKDEy/_bayeisan_kdey.py @@ -12,7 +12,8 @@ from quapy.method.aggregative import AggregativeSoftQuantifier from tqdm import tqdm import quapy.functional as F import emcee - +from collections.abc import Iterable +from numbers import Number import jax import jax.numpy as jnp @@ -58,6 +59,7 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC): step_size=0.05, temperature=1., engine='numpyro', + prior='uniform', verbose: bool = False): if num_warmup <= 0: @@ -66,6 +68,8 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC): raise ValueError(f'parameter {num_samples=} must be a positive integer') assert explore in ['simplex', 'clr', 'ilr'], \ f'unexpected value for param {explore=}; valid ones are "simplex", "clr", and "ilr"' + assert prior == 'uniform' or (isinstance(prior, Iterable) and all(isinstance(v, Number) for v in prior)), \ + f'wrong type for {prior=}; expected "uniform" or an array-like of real values' # assert temperature>0., f'temperature must be >0' assert engine in ['rw-mh', 'emcee', 'numpyro'] @@ -81,6 +85,7 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC): self.step_size = step_size self.temperature = temperature self.engine = engine + self.prior = prior self.verbose = verbose def aggregation_fit(self, classif_predictions, labels): @@ -89,8 +94,12 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC): def aggregate(self, classif_predictions): if self.engine == 'rw-mh': + if self.prior != 
'uniform': + raise RuntimeError('prior is not yet implemented in rw-mh') self.prevalence_samples = self._bayesian_kde(classif_predictions, init=None, verbose=self.verbose) elif self.engine == 'emcee': + if self.prior != 'uniform': + raise RuntimeError('prior is not yet implemented in emcee') self.prevalence_samples = self._bayesian_emcee(classif_predictions) elif self.engine == 'numpyro': self.prevalence_samples = self._bayesian_numpyro(classif_predictions) @@ -237,6 +246,11 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC): # move to jax test_densities = jnp.array(test_densities) + n_classes = X_probs.shape[-1] + if self.prior == 'uniform': + alpha = [1.]*n_classes + else: + alpha = self.prior kernel = NUTS(self._numpyro_model) mcmc = MCMC( @@ -248,7 +262,7 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC): ) rng_key = jax.random.PRNGKey(self.mcmc_seed) - mcmc.run(rng_key, test_densities) + mcmc.run(rng_key, test_densities=test_densities, alpha=alpha) samples_z = mcmc.get_samples()["z"] @@ -258,26 +272,33 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC): return samples_prev - def _numpyro_model(self, test_densities): + def _numpyro_model(self, test_densities, alpha): """ - test_densities: shape (C, N) + test_densities: shape (n_classes, n_instances,) """ - C = test_densities.shape[0] + n_classes = test_densities.shape[0] ilr = ILRtransformation(jax_mode=True) - # sample in unconstrained R^{C-1} + # sample in unconstrained R^(n_classes-1) z = numpyro.sample( "z", - dist.Normal(0.0, 1.0).expand([C - 1]) + dist.Normal(0.0, 1.0).expand([n_classes - 1]) ) - prev = ilr.inverse(z) # simplex, shape (C,) + prev = ilr.inverse(z) # simplex, shape (n_classes,) + + # prior + if alpha is not None: + alpha = jnp.asarray(alpha) + numpyro.factor( + 'log_prior', dist.Dirichlet(alpha).log_prob(prev) + ) + # if alpha is None, then this corresponds to a weak logistic-normal prior # likelihood likelihoods
= jnp.dot(prev, test_densities) numpyro.factor( - "loglik", - (1.0 / self.temperature) * jnp.sum(jnp.log(likelihoods + 1e-10)) + "loglik", (1.0 / self.temperature) * jnp.sum(jnp.log(likelihoods + 1e-10)) ) @@ -285,8 +306,6 @@ def in_simplex(x): return np.all(x >= 0) and np.isclose(x.sum(), 1) - - class ILRtransformation(F.CompositionalTransformation): def __init__(self, jax_mode=False): self.jax_mode = jax_mode diff --git a/BayesianKDEy/full_experiments.py b/BayesianKDEy/full_experiments.py index 47c4155..3ca28f1 100644 --- a/BayesianKDEy/full_experiments.py +++ b/BayesianKDEy/full_experiments.py @@ -85,6 +85,7 @@ def methods(): yield f'BaKDE-numpyro-T*', KDEyML(LR()), kdey_hyper, lambda hyper: BayesianKDEy(mcmc_seed=0, engine='numpyro', temperature=None, **hyper), multiclass_method yield f'BaKDE-Ait-numpyro', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', **hyper), multiclass_method yield f'BaKDE-Ait-numpyro-T*', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', temperature=None, **hyper), multiclass_method + yield f'BaKDE-Ait-numpyro-T*-U', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', temperature=None, prior='uniform', **hyper), multiclass_method yield f'BaKDE-Ait-numpyro-T*ILR', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', temperature=None, region='ellipse-ilr', **hyper), multiclass_method yield f'BaKDE-numpyro-T10', KDEyML(LR()), kdey_hyper, lambda hyper: BayesianKDEy(mcmc_seed=0, engine='numpyro', temperature=10., **hyper), multiclass_method yield f'BaKDE-numpyro*CLR', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', **hyper), multiclass_method