This commit is contained in:
Alejandro Moreo Fernandez 2026-01-13 12:32:28 +01:00
parent ca981836b4
commit 724e1b13a0
1 changed files with 3 additions and 2 deletions

View File

@ -68,7 +68,8 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
raise ValueError(f'parameter {num_samples=} must be a positive integer') raise ValueError(f'parameter {num_samples=} must be a positive integer')
assert explore in ['simplex', 'clr', 'ilr'], \ assert explore in ['simplex', 'clr', 'ilr'], \
f'unexpected value for param {explore=}; valid ones are "simplex", "clr", and "ilr"' f'unexpected value for param {explore=}; valid ones are "simplex", "clr", and "ilr"'
assert prior == 'uniform' or (isinstance(prior, Iterable) and all(isinstance(v, Number) for v in prior)), \ assert ((isinstance(prior, str) and prior == 'uniform') or
(isinstance(prior, Iterable) and all(isinstance(v, Number) for v in prior))), \
f'wrong type for {prior=}; expected "uniform" or an array-like of real values' f'wrong type for {prior=}; expected "uniform" or an array-like of real values'
# assert temperature>0., f'temperature must be >0' # assert temperature>0., f'temperature must be >0'
assert engine in ['rw-mh', 'emcee', 'numpyro'] assert engine in ['rw-mh', 'emcee', 'numpyro']
@ -247,7 +248,7 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
# move to jax # move to jax
test_densities = jnp.array(test_densities) test_densities = jnp.array(test_densities)
n_classes = X_probs.shape[-1] n_classes = X_probs.shape[-1]
if self.prior == 'uniform': if isinstance(self.prior, str) and self.prior == 'uniform':
alpha = [1.]*n_classes alpha = [1.]*n_classes
else: else:
alpha = self.prior alpha = self.prior