diff --git a/BayesianKDEy/_bayeisan_kdey.py b/BayesianKDEy/_bayeisan_kdey.py index 3c58900..356b259 100644 --- a/BayesianKDEy/_bayeisan_kdey.py +++ b/BayesianKDEy/_bayeisan_kdey.py @@ -68,7 +68,8 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC): raise ValueError(f'parameter {num_samples=} must be a positive integer') assert explore in ['simplex', 'clr', 'ilr'], \ f'unexpected value for param {explore=}; valid ones are "simplex", "clr", and "ilr"' - assert prior == 'uniform' or (isinstance(prior, Iterable) and all(isinstance(v, Number) for v in prior)), \ + assert ((isinstance(prior, str) and prior == 'uniform') or + (isinstance(prior, Iterable) and all(isinstance(v, Number) for v in prior))), \ f'wrong type for {prior=}; expected "uniform" or an array-like of real values' # assert temperature>0., f'temperature must be >0' assert engine in ['rw-mh', 'emcee', 'numpyro'] @@ -247,7 +248,7 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC): # move to jax test_densities = jnp.array(test_densities) n_classes = X_probs.shape[-1] - if self.prior == 'uniform': + if isinstance(self.prior, str) and self.prior == 'uniform': alpha = [1.]*n_classes else: alpha = self.prior