from collections.abc import Iterable
from numbers import Number

import numpy as np
from sklearn.base import BaseEstimator
from tqdm import tqdm

import emcee
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

import quapy.functional as F
from quapy.functional import CLRtransformation
from quapy.method._kdey import KDEBase
from quapy.method.aggregative import AggregativeSoftQuantifier
from quapy.method.confidence import WithConfidenceABC, ConfidenceRegionABC

from BayesianKDEy.commons import ILRtransformation, in_simplex


class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
    """
    Bayesian version of KDEy. Instead of returning a single prevalence estimate, this method draws
    MCMC samples from the posterior distribution of the class prevalence values, from which a point
    estimate (the posterior mean) and a confidence region are derived. See the usage sketch at the
    end of this module for an illustrative example.

    :param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
        the one indicated in `qp.environ['DEFAULT_CLS']`
    :param fit_classifier: whether the classifier is to be trained (default True) or is assumed to be
        already trained
    :param val_split: specifies the data used for generating classifier predictions. This specification
        can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
        be extracted from the training set; or as an integer (default 5), indicating that the predictions
        are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
        for `k`); or as a tuple `(X,y)` defining the specific set of data to use for validation. Set to
        None when the method does not require any validation data, in order to avoid that some portion of
        the training data be wasted.
    :param kernel: the kernel used by the KDE (default 'gaussian')
    :param bandwidth: the bandwidth of the kernel (default 0.1)
    :param num_warmup: number of warmup iterations for the MCMC sampler (default 500)
    :param num_samples: number of samples to draw from the posterior (default 1000)
    :param mcmc_seed: random seed for the MCMC sampler (default 0)
    :param confidence_level: float in [0,1] to construct a confidence region around the point estimate (default 0.95)
    :param region: string, set to `intervals` for constructing confidence intervals (default), or to
        `ellipse` for constructing an ellipse in the probability simplex, or to `ellipse-clr` for
        constructing an ellipse in the Centered-Log Ratio (CLR) unconstrained space.
    :param explore: the space in which the random-walk proposals are generated when `engine='rw-mh'`;
        one of 'simplex' (default), 'clr' (centered log-ratio space), or 'ilr' (isometric log-ratio space)
    :param step_size: initial step size of the random-walk proposals (default 0.05), adapted during
        sampling towards a target acceptance rate
    :param temperature: scaling factor applied to the log-likelihood, which gets multiplied by
        1/temperature (default 1., i.e., no tempering)
    :param engine: the MCMC backend; one of 'rw-mh' (random-walk Metropolis-Hastings), 'emcee'
        (affine-invariant ensemble sampler), or 'numpyro' (NUTS, the default)
    :param prior: an array-like with the alpha parameters of a Dirichlet prior, or the string 'uniform'
        for a uniform, uninformative prior (default)
    :param verbose: bool, whether to display progress bars
    :param kwargs: additional hyperparameters to pass to the classifier; keys must start with 'classifier__'
    """

    def __init__(self,
                 classifier: BaseEstimator = None,
                 fit_classifier=True,
                 val_split: int = 5,
                 kernel='gaussian',
                 bandwidth=0.1,
                 num_warmup: int = 500,
                 num_samples: int = 1_000,
                 mcmc_seed: int = 0,
                 confidence_level: float = 0.95,
                 region: str = 'intervals',
                 explore='simplex',
                 step_size=0.05,
                 temperature=1.,
                 engine='numpyro',
                 prior='uniform',
                 verbose: bool = False,
                 **kwargs):

        if num_warmup <= 0:
            raise ValueError(f'parameter {num_warmup=} must be a positive integer')
        if num_samples <= 0:
            raise ValueError(f'parameter {num_samples=} must be a positive integer')
        assert explore in ['simplex', 'clr', 'ilr'], \
            f'unexpected value for param {explore=}; valid ones are "simplex", "clr", and "ilr"'
        assert ((isinstance(prior, str) and prior == 'uniform') or
                (isinstance(prior, Iterable) and all(isinstance(v, Number) for v in prior))), \
            f'wrong type for {prior=}; expected "uniform" or an array-like of real values'
        # assert temperature > 0., f'temperature must be >0'
        assert engine in ['rw-mh', 'emcee', 'numpyro'], \
            f'unexpected value for param {engine=}; valid ones are "rw-mh", "emcee", and "numpyro"'

        super().__init__(classifier, fit_classifier, val_split)
        assert all(k.startswith('classifier__') for k in kwargs.keys()), \
            'unexpected kwargs; all keys must start with "classifier__"'
        self.classifier.set_params(**{k.replace('classifier__', ''): v for k, v in kwargs.items()})  # <- improve

        self.bandwidth = KDEBase._check_bandwidth(bandwidth, kernel)
        self.kernel = self._check_kernel(kernel)
        self.num_warmup = num_warmup
        self.num_samples = num_samples
        self.mcmc_seed = mcmc_seed
        self.confidence_level = confidence_level
        self.region = region
        self.explore = explore
        self.step_size = step_size
        self.temperature = temperature
        self.engine = engine
        self.prior = prior
        self.verbose = verbose

    def aggregation_fit(self, classif_predictions, labels):
        """
        Trains the aggregation function, i.e., fits one KDE per class on the classifier predictions
        obtained for the validation instances.
        """
        self.mix_densities = self.get_mixture_components(
            classif_predictions, labels, self.classes_, self.bandwidth, self.kernel)
        return self

    def aggregate(self, classif_predictions: np.ndarray):
        """
        Draws samples from the posterior distribution of the prevalence values given the classifier
        predictions for the test instances, and returns the posterior mean as the point estimate.
        The samples are kept in `self.prevalence_samples` for later use (see :meth:`predict_conf`).
        """
        uniform_prior = isinstance(self.prior, str) and self.prior == 'uniform'
        if self.engine == 'rw-mh':
            if not uniform_prior:
                raise RuntimeError('prior is not yet implemented in rw-mh')
            self.prevalence_samples = self._bayesian_kde(classif_predictions, init=None, verbose=self.verbose)
        elif self.engine == 'emcee':
            if not uniform_prior:
                raise RuntimeError('prior is not yet implemented in emcee')
            self.prevalence_samples = self._bayesian_emcee(classif_predictions)
        elif self.engine == 'numpyro':
            self.ilr = ILRtransformation(jax_mode=True)
            self.prevalence_samples = self._bayesian_numpyro(classif_predictions)
        return self.prevalence_samples.mean(axis=0)

    def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
        """
        Returns the point estimate of the class prevalence values along with a confidence region
        constructed from the posterior samples.
        """
        if confidence_level is None:
            confidence_level = self.confidence_level
        classif_predictions = self.classify(instances)
        point_estimate = self.aggregate(classif_predictions)
        samples = self.prevalence_samples  # made available by the call to "aggregate"
        region = WithConfidenceABC.construct_region(samples, confidence_level=confidence_level, method=self.region)
        return point_estimate, region

    def _bayesian_kde(self, X_probs, init=None, verbose=False):
        """
        Random-walk Metropolis-Hastings sampler for the posterior over prevalence vectors.

        By Bayes' theorem,
            P(prev|data) = P(data|prev) * P(prev) / P(data),
        i.e., posterior = likelihood * prior / evidence.
        The likelihood of each test point is assumed to be the KDE mixture sum_i prev_i * f_i(x),
        where f_i is the class-conditional density estimated by the i-th KDE, and the prior is
        uniform in the simplex.
        """
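        # The sampler below proposes a neighbour prev', computes the difference in (tempered)
        # log-posterior, and accepts the move with probability min(1, exp(delta)), i.e., it accepts
        # whenever log(u) < delta for u ~ Uniform(0, 1); the evidence P(data) cancels out in the
        # ratio. The step size is adapted during sampling towards a target acceptance rate of ~30%.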
        rng = np.random.default_rng(self.mcmc_seed)
        kdes = self.mix_densities
        test_densities = np.asarray([self.pdf(kde_i, X_probs, self.kernel) for kde_i in kdes])

        def log_likelihood(prev, epsilon=1e-10):
            # tempered log-likelihood of the KDE mixture: (1/T) * sum_x log(sum_i prev_i * f_i(x))
            test_likelihoods = prev @ test_densities
            test_loglikelihood = np.log(test_likelihoods + epsilon)
            return (1. / self.temperature) * np.sum(test_loglikelihood)

        # def log_prior(prev):
        #     todo: adapt to arbitrary prior knowledge (e.g., something around training prevalence)
        #     return 1/np.sum((prev-init)**2)  # it is not 1, but we assume uniform, so anyway it is a useless constant

        # def log_prior(prev, alpha_scale=1000):
        #     alpha = np.array(init) * alpha_scale
        #     return dirichlet.logpdf(prev, alpha)

        def log_prior(prev):
            # uniform prior in the simplex: constant log-density, so it adds nothing to the log-posterior
            return 0

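        # Proposal generation for the random walk: 'simplex' adds Gaussian noise to the prevalence
        # vector and maps the result back onto the simplex; 'clr' and 'ilr' instead add Gaussian
        # noise in the unconstrained centered log-ratio / isometric log-ratio spaces, whose inverse
        # transformations always return a valid point of the simplex.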
        def sample_neighbour(prev, step_size):
            # random-walk Metropolis-Hastings
            d = len(prev)
            neighbour = None
            if self.explore == 'simplex':
                dir_noise = rng.normal(scale=step_size / np.sqrt(d), size=d)
                neighbour = F.normalize_prevalence(prev + dir_noise, method='mapsimplex')
            elif self.explore == 'clr':
                clr = CLRtransformation()
                clr_point = clr(prev)
                dir_noise = rng.normal(scale=step_size, size=d)
                clr_neighbour = clr_point + dir_noise
                neighbour = clr.inverse(clr_neighbour)
                assert in_simplex(neighbour), 'wrong CLR transformation'
            elif self.explore == 'ilr':
                ilr = ILRtransformation()
                ilr_point = ilr(prev)
                dir_noise = rng.normal(scale=step_size, size=d - 1)
                ilr_neighbour = ilr_point + dir_noise
                neighbour = ilr.inverse(ilr_neighbour)
                assert in_simplex(neighbour), 'wrong ILR transformation'
            return neighbour

        n_classes = X_probs.shape[1]
        current_prev = F.uniform_prevalence(n_classes) if init is None else init
        current_likelihood = log_likelihood(current_prev) + log_prior(current_prev)

        # Metropolis-Hastings with adaptive step size
        step_size = self.step_size
        target_acceptance = 0.3
        adapt_rate = 0.05
        acceptance_history = []

        samples = []
        total_steps = self.num_samples + self.num_warmup
        for i in tqdm(range(total_steps), total=total_steps, disable=not verbose):
            proposed_prev = sample_neighbour(current_prev, step_size)

            # log-probability of acceptance
            proposed_likelihood = log_likelihood(proposed_prev) + log_prior(proposed_prev)
            acceptance = proposed_likelihood - current_likelihood

            # decide acceptance
            accepted = np.log(rng.random()) < acceptance
            if accepted:
                current_prev = proposed_prev
                current_likelihood = proposed_likelihood

            samples.append(current_prev)
            acceptance_history.append(1. if accepted else 0.)

            # if i < self.num_warmup and i%10==0 and len(acceptance_history)>=100:
            if i % 10 == 0 and len(acceptance_history) >= 100:
                recent_accept_rate = np.mean(acceptance_history[-100:])
                step_size *= np.exp(adapt_rate * (recent_accept_rate - target_acceptance))
                # step_size = float(np.clip(step_size, min_step, max_step))
                if verbose and i % 100 == 0:
                    print(f'acceptance-rate={recent_accept_rate * 100:.3f}%, step-size={step_size:.5f}')

        # remove "warmup" initial iterations
        samples = np.asarray(samples[self.num_warmup:])
        return samples

    def _bayesian_emcee(self, X_probs):
        """
        Posterior sampling with emcee's affine-invariant ensemble sampler. The walkers move in the
        unconstrained CLR space; samples are mapped back to the simplex before being returned.
        """
        ndim = X_probs.shape[1]
        nwalkers = 32
        if nwalkers < (ndim * 2):
            nwalkers = ndim * 2 + 1

        f = CLRtransformation()

        def log_likelihood(unconstrained, test_densities, epsilon=1e-10):
            prev = f.inverse(unconstrained)
            test_likelihoods = prev @ test_densities
            test_loglikelihood = np.log(test_likelihoods + epsilon)
            return np.sum(test_loglikelihood)

        kdes = self.mix_densities
        test_densities = np.asarray([self.pdf(kde_i, X_probs, self.kernel) for kde_i in kdes])
        # test_densities_unconstrained = [f(t) for t in test_densities]

        # p0 = np.random.normal(nwalkers, ndim)
        p0 = F.uniform_prevalence_sampling(ndim, nwalkers)
        p0 = f(p0)
        sampler = emcee.EnsembleSampler(nwalkers, ndim, log_likelihood, args=[test_densities])

        state = sampler.run_mcmc(p0, self.num_warmup, skip_initial_state_check=True)
        sampler.reset()
        sampler.run_mcmc(state, self.num_samples, skip_initial_state_check=True)
        samples = sampler.get_chain(flat=True)
        samples = f.inverse(samples)
        return samples

    def _bayesian_numpyro(self, X_probs):
        """
        Posterior sampling with NUTS (numpyro). The model is parameterized in the unconstrained ILR
        space; the ILR-inverse of the samples, which lives in the simplex, is returned.
        """
        kdes = self.mix_densities
        # test_densities = np.asarray(
        #     [self.pdf(kde_i, X_probs, self.kernel) for kde_i in kdes]
        # )
        test_log_densities = np.asarray(
            [self.pdf(kde_i, X_probs, self.kernel, log_densities=True) for kde_i in kdes]
        )
        if self.verbose:
            print(f'min={np.min(test_log_densities)}')
            print(f'max={np.max(test_log_densities)}')

        n_classes = X_probs.shape[-1]
        if isinstance(self.prior, str) and self.prior == 'uniform':
            alpha = [1.] * n_classes
        else:
            alpha = self.prior

        kernel = NUTS(self._numpyro_model)
        mcmc = MCMC(
            kernel,
            num_warmup=self.num_warmup,
            num_samples=self.num_samples,
            num_chains=1,
            progress_bar=self.verbose,
        )

        rng_key = jax.random.PRNGKey(self.mcmc_seed)
        mcmc.run(rng_key, test_log_densities=test_log_densities, alpha=alpha)

        samples_z = mcmc.get_samples()["z"]

        # back to the simplex
        samples_prev = np.asarray(self.ilr.inverse(np.asarray(samples_z)))

        return samples_prev

    def _numpyro_model(self, test_log_densities, alpha):
        """
        test_log_densities: shape (n_classes, n_instances)
        alpha: parameters of the Dirichlet prior, or None
        """
        n_classes = test_log_densities.shape[0]

        # sample in unconstrained R^(n_classes-1)
        z = numpyro.sample(
            "z",
            dist.Normal(0.0, 1.0).expand([n_classes - 1])
        )

        prev = self.ilr.inverse(z)  # simplex, shape (n_classes,)

        # for numerical stability:
        # eps = 1e-10
        # prev_safe = jnp.clip(prev, eps, 1.0)
        # prev_safe = prev_safe / jnp.sum(prev_safe)
        # prev = prev_safe

        # from jax import debug
        # debug.print("prev = {}", prev)
        # numpyro.factor(
        #     "check_prev",
        #     jnp.where(jnp.all(prev > 0), 0.0, -jnp.inf)
        # )

        # prior
        if alpha is not None:
            alpha = jnp.asarray(alpha)
            numpyro.factor(
                'log_prior', dist.Dirichlet(alpha).log_prob(prev)
            )
        # if alpha is None, then this corresponds to a weak logistic-normal prior

        # likelihood
        # test_densities = jnp.array(test_densities)
        # likelihoods = jnp.dot(prev, test_densities)
        # likelihoods = jnp.clip(likelihoods, 1e-12, jnp.inf)  # numerical stability
        # numpyro.factor(
        #     "loglik", (1.0 / self.temperature) * jnp.sum(jnp.log(likelihoods))
        # )

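        # The (tempered) log-likelihood is computed in log space for numerical stability:
        # log(sum_i prev_i * f_i(x_j)) = logsumexp_i(log(prev_i) + log(f_i(x_j))), summed over all
        # test instances j; this computes the same quantity as the commented-out dense version
        # above, but avoids underflow when the densities are very small.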
        log_likelihood = jnp.sum(
            logsumexp(jnp.log(prev)[:, None] + test_log_densities, axis=0)
        )
        numpyro.factor(
            "loglik", (1.0 / self.temperature) * log_likelihood
        )
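

# Illustrative usage sketch (not part of the class): the snippet below shows how one might train
# BayesianKDEy on synthetic data and obtain a point estimate together with a confidence region.
# It assumes the quapy interface in which aggregative quantifiers are trained via `fit(X, y)`;
# the synthetic data, the LogisticRegression classifier, and all hyperparameter values are
# arbitrary choices for illustration only.
if __name__ == '__main__':
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split

    # synthetic 3-class dataset
    X, y = make_classification(n_samples=2000, n_features=10, n_informative=8,
                               n_classes=3, random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)

    quantifier = BayesianKDEy(LogisticRegression(max_iter=1000),
                              bandwidth=0.1,
                              engine='numpyro',
                              num_warmup=500,
                              num_samples=1000,
                              verbose=True)
    quantifier.fit(X_train, y_train)

    # point estimate (posterior mean) and confidence region around it
    point_estimate, region = quantifier.predict_conf(X_test, confidence_level=0.95)
    print(f'estimated prevalence: {point_estimate}')
    print(f'true test prevalence: {np.bincount(y_test) / len(y_test)}')
    print(f'confidence region: {region}')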