bugfix
This commit is contained in:
parent
ca981836b4
commit
724e1b13a0
|
|
@ -68,7 +68,8 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
|
|||
raise ValueError(f'parameter {num_samples=} must be a positive integer')
|
||||
assert explore in ['simplex', 'clr', 'ilr'], \
|
||||
f'unexpected value for param {explore=}; valid ones are "simplex", "clr", and "ilr"'
|
||||
assert prior == 'uniform' or (isinstance(prior, Iterable) and all(isinstance(v, Number) for v in prior)), \
|
||||
assert ((isinstance(prior, str) and prior == 'uniform') or
|
||||
(isinstance(prior, Iterable) and all(isinstance(v, Number) for v in prior))), \
|
||||
f'wrong type for {prior=}; expected "uniform" or an array-like of real values'
|
||||
# assert temperature>0., f'temperature must be >0'
|
||||
assert engine in ['rw-mh', 'emcee', 'numpyro']
|
||||
|
|
@ -247,7 +248,7 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
|
|||
# move to jax
|
||||
test_densities = jnp.array(test_densities)
|
||||
n_classes = X_probs.shape[-1]
|
||||
if self.prior == 'uniform':
|
||||
if isinstance(self.prior, str) and self.prior == 'uniform':
|
||||
alpha = [1.]*n_classes
|
||||
else:
|
||||
alpha = self.prior
|
||||
|
|
|
|||
Loading…
Reference in New Issue