added prior
This commit is contained in:
parent
ae9503a43b
commit
c6fb46cf70
|
|
@ -12,7 +12,8 @@ from quapy.method.aggregative import AggregativeSoftQuantifier
|
|||
from tqdm import tqdm
|
||||
import quapy.functional as F
|
||||
import emcee
|
||||
|
||||
from collections.abc import Iterable
|
||||
from numbers import Number
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
|
@ -58,6 +59,7 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
|
|||
step_size=0.05,
|
||||
temperature=1.,
|
||||
engine='numpyro',
|
||||
prior='uniform',
|
||||
verbose: bool = False):
|
||||
|
||||
if num_warmup <= 0:
|
||||
|
|
@ -66,6 +68,8 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
|
|||
raise ValueError(f'parameter {num_samples=} must be a positive integer')
|
||||
assert explore in ['simplex', 'clr', 'ilr'], \
|
||||
f'unexpected value for param {explore=}; valid ones are "simplex", "clr", and "ilr"'
|
||||
assert prior == 'uniform' or (isinstance(prior, Iterable) and all(isinstance(v, Number) for v in prior)), \
|
||||
f'wrong type for {prior=}; expected "uniform" or an array-like of real values'
|
||||
# assert temperature>0., f'temperature must be >0'
|
||||
assert engine in ['rw-mh', 'emcee', 'numpyro']
|
||||
|
||||
|
|
@ -81,6 +85,7 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
|
|||
self.step_size = step_size
|
||||
self.temperature = temperature
|
||||
self.engine = engine
|
||||
self.prior = prior
|
||||
self.verbose = verbose
|
||||
|
||||
def aggregation_fit(self, classif_predictions, labels):
|
||||
|
|
@ -89,8 +94,12 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
|
|||
|
||||
def aggregate(self, classif_predictions):
|
||||
if self.engine == 'rw-mh':
|
||||
if self.prior != 'uniform':
|
||||
raise RuntimeError('prior is not yet implemented in rw-mh')
|
||||
self.prevalence_samples = self._bayesian_kde(classif_predictions, init=None, verbose=self.verbose)
|
||||
elif self.engine == 'emcee':
|
||||
if self.prior != 'uniform':
|
||||
raise RuntimeError('prior is not yet implemented in emcee')
|
||||
self.prevalence_samples = self._bayesian_emcee(classif_predictions)
|
||||
elif self.engine == 'numpyro':
|
||||
self.prevalence_samples = self._bayesian_numpyro(classif_predictions)
|
||||
|
|
@ -237,6 +246,11 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
|
|||
|
||||
# move to jax
|
||||
test_densities = jnp.array(test_densities)
|
||||
n_classes = X_probs.shape[-1]
|
||||
if self.prior == 'uniform':
|
||||
alpha = [1.]*n_classes
|
||||
else:
|
||||
alpha = self.prior
|
||||
|
||||
kernel = NUTS(self._numpyro_model)
|
||||
mcmc = MCMC(
|
||||
|
|
@ -248,7 +262,7 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
|
|||
)
|
||||
|
||||
rng_key = jax.random.PRNGKey(self.mcmc_seed)
|
||||
mcmc.run(rng_key, test_densities)
|
||||
mcmc.run(rng_key, test_densities=test_densities, alpha=alpha)
|
||||
|
||||
samples_z = mcmc.get_samples()["z"]
|
||||
|
||||
|
|
@ -258,26 +272,33 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
|
|||
|
||||
return samples_prev
|
||||
|
||||
def _numpyro_model(self, test_densities):
|
||||
def _numpyro_model(self, test_densities, alpha):
|
||||
"""
|
||||
test_densities: shape (C, N)
|
||||
test_densities: shape (n_classes, n_instances,)
|
||||
"""
|
||||
C = test_densities.shape[0]
|
||||
n_classes = test_densities.shape[0]
|
||||
ilr = ILRtransformation(jax_mode=True)
|
||||
|
||||
# sample in unconstrained R^{C-1}
|
||||
# sample in unconstrained R^(n_classes-1)
|
||||
z = numpyro.sample(
|
||||
"z",
|
||||
dist.Normal(0.0, 1.0).expand([C - 1])
|
||||
dist.Normal(0.0, 1.0).expand([n_classes - 1])
|
||||
)
|
||||
|
||||
prev = ilr.inverse(z) # simplex, shape (C,)
|
||||
prev = ilr.inverse(z) # simplex, shape (n_classes,)
|
||||
|
||||
# prior
|
||||
if alpha is not None:
|
||||
alpha = jnp.asarray(alpha)
|
||||
numpyro.factor(
|
||||
'log_prior', dist.Dirichlet(alpha.log_prob(prev))
|
||||
)
|
||||
# if alpha is None, then this corresponds to a weak logistic-normal prior
|
||||
|
||||
# likelihood
|
||||
likelihoods = jnp.dot(prev, test_densities)
|
||||
numpyro.factor(
|
||||
"loglik",
|
||||
(1.0 / self.temperature) * jnp.sum(jnp.log(likelihoods + 1e-10))
|
||||
"loglik", (1.0 / self.temperature) * jnp.sum(jnp.log(likelihoods + 1e-10))
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -285,8 +306,6 @@ def in_simplex(x):
|
|||
return np.all(x >= 0) and np.isclose(x.sum(), 1)
|
||||
|
||||
|
||||
|
||||
|
||||
class ILRtransformation(F.CompositionalTransformation):
|
||||
def __init__(self, jax_mode=False):
|
||||
self.jax_mode = jax_mode
|
||||
|
|
|
|||
|
|
@ -85,6 +85,7 @@ def methods():
|
|||
yield f'BaKDE-numpyro-T*', KDEyML(LR()), kdey_hyper, lambda hyper: BayesianKDEy(mcmc_seed=0, engine='numpyro', temperature=None, **hyper), multiclass_method
|
||||
yield f'BaKDE-Ait-numpyro', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', **hyper), multiclass_method
|
||||
yield f'BaKDE-Ait-numpyro-T*', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', temperature=None, **hyper), multiclass_method
|
||||
yield f'BaKDE-Ait-numpyro-T*-U', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', temperature=None, prior='uniform', **hyper), multiclass_method
|
||||
yield f'BaKDE-Ait-numpyro-T*ILR', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', temperature=None, region='ellipse-ilr', **hyper), multiclass_method
|
||||
yield f'BaKDE-numpyro-T10', KDEyML(LR()), kdey_hyper, lambda hyper: BayesianKDEy(mcmc_seed=0, engine='numpyro', temperature=10., **hyper), multiclass_method
|
||||
yield f'BaKDE-numpyro*CLR', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', **hyper), multiclass_method
|
||||
|
|
|
|||
Loading…
Reference in New Issue