bugfix
This commit is contained in:
parent
ca981836b4
commit
724e1b13a0
|
|
@ -68,7 +68,8 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
|
||||||
raise ValueError(f'parameter {num_samples=} must be a positive integer')
|
raise ValueError(f'parameter {num_samples=} must be a positive integer')
|
||||||
assert explore in ['simplex', 'clr', 'ilr'], \
|
assert explore in ['simplex', 'clr', 'ilr'], \
|
||||||
f'unexpected value for param {explore=}; valid ones are "simplex", "clr", and "ilr"'
|
f'unexpected value for param {explore=}; valid ones are "simplex", "clr", and "ilr"'
|
||||||
assert prior == 'uniform' or (isinstance(prior, Iterable) and all(isinstance(v, Number) for v in prior)), \
|
assert ((isinstance(prior, str) and prior == 'uniform') or
|
||||||
|
(isinstance(prior, Iterable) and all(isinstance(v, Number) for v in prior))), \
|
||||||
f'wrong type for {prior=}; expected "uniform" or an array-like of real values'
|
f'wrong type for {prior=}; expected "uniform" or an array-like of real values'
|
||||||
# assert temperature>0., f'temperature must be >0'
|
# assert temperature>0., f'temperature must be >0'
|
||||||
assert engine in ['rw-mh', 'emcee', 'numpyro']
|
assert engine in ['rw-mh', 'emcee', 'numpyro']
|
||||||
|
|
@ -247,7 +248,7 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
|
||||||
# move to jax
|
# move to jax
|
||||||
test_densities = jnp.array(test_densities)
|
test_densities = jnp.array(test_densities)
|
||||||
n_classes = X_probs.shape[-1]
|
n_classes = X_probs.shape[-1]
|
||||||
if self.prior == 'uniform':
|
if isinstance(self.prior, str) and self.prior == 'uniform':
|
||||||
alpha = [1.]*n_classes
|
alpha = [1.]*n_classes
|
||||||
else:
|
else:
|
||||||
alpha = self.prior
|
alpha = self.prior
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue