# QuaPy/BayesianKDEy/_bayesian_mapls.py

import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, HMC
import jax.random as random
from sklearn.base import BaseEstimator
from jax.scipy.special import logsumexp
from BayesianKDEy.commons import ILRtransformation
from quapy.method.aggregative import AggregativeSoftQuantifier
from quapy.method.confidence import WithConfidenceABC, ConfidenceRegionABC
import quapy.functional as F
class BayesianMAPLS(AggregativeSoftQuantifier, WithConfidenceABC):
"""
:param classifier:
:param fit_classifier:
:param val_split:
:param exact_train_prev: set to True (default) for using the true training prevalence as the initial
observation; set to False for computing the training prevalence as an estimate of it, i.e., as the
expected value of the posterior probabilities of the training instances.
:param num_samples:
:param mcmc_seed:
:param confidence_level:
:param region:
"""
def __init__(self,
classifier: BaseEstimator = None,
fit_classifier=True,
val_split: int = 5,
exact_train_prev=True,
num_warmup: int = 500,
num_samples: int = 1_000,
mcmc_seed: int = 0,
confidence_level: float = 0.95,
region: str = 'intervals',
temperature=1.,
prior='uniform',
mapls_chain_init=True,
verbose=False
):
if num_samples <= 0:
raise ValueError(f'parameter {num_samples=} must be a positive integer')
super().__init__(classifier, fit_classifier, val_split)
self.exact_train_prev = exact_train_prev
self.num_warmup = num_warmup
self.num_samples = num_samples
self.mcmc_seed = mcmc_seed
self.confidence_level = confidence_level
self.region = region
self.temperature = temperature
self.prior = prior
self.mapls_chain_init = mapls_chain_init
self.verbose = verbose
def aggregation_fit(self, classif_predictions, labels):
self.train_post = classif_predictions
if self.exact_train_prev:
self.train_prevalence = F.prevalence_from_labels(labels, classes=self.classes_)
else:
self.train_prevalence = F.prevalence_from_probabilities(classif_predictions)
self.ilr = ILRtransformation(jax_mode=True)
return self
def aggregate(self, classif_predictions: np.ndarray):
n_test, n_classes = classif_predictions.shape
pi_star, lam = mapls(
self.train_post,
test_probs=classif_predictions,
pz=self.train_prevalence,
return_lambda=True
)
        # pi_star: MAP estimate on the simplex, shape (n_classes,); convert it to ILR space
z0 = self.ilr(pi_star)
if self.prior == 'uniform':
alpha = [1.] * n_classes
elif self.prior == 'map':
alpha_0 = alpha0_from_lamda(lam, n_test=n_test, n_classes=n_classes)
alpha = [alpha_0] * n_classes
elif self.prior == 'map2':
lam2 = get_lamda(
test_probs=classif_predictions,
pz=self.train_prevalence,
q_prior=pi_star,
dvg=kl_div
)
alpha_0 = alpha0_from_lamda(lam2, n_test=n_test, n_classes=n_classes)
alpha = [alpha_0] * n_classes
else:
alpha = self.prior
kernel = NUTS(self.model)
mcmc = MCMC(
kernel,
num_warmup=self.num_warmup,
num_samples=self.num_samples,
num_chains=1,
progress_bar=self.verbose
)
mcmc.run(
random.PRNGKey(self.mcmc_seed),
test_posteriors=classif_predictions,
alpha=alpha,
init_params={"z": z0} if self.mapls_chain_init else None
)
samples = mcmc.get_samples()["z"]
self.prevalence_samples = self.ilr.inverse(samples)
return self.prevalence_samples.mean(axis=0)
def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
if confidence_level is None:
confidence_level = self.confidence_level
classif_predictions = self.classify(instances)
point_estimate = self.aggregate(classif_predictions)
samples = self.prevalence_samples # available after calling "aggregate" function
region = WithConfidenceABC.construct_region(samples, confidence_level=confidence_level, method=self.region)
return point_estimate, region
    def log_likelihood(self, test_classif, test_prev, train_prev):
        """
        Log-likelihood (up to an additive constant) of the test posteriors under prior probability shift:
        sum_i log( sum_c P_s(c|x_i) * test_prev[c] / train_prev[c] )
        """
        log_w = jnp.log(test_prev) - jnp.log(train_prev)
        # variant averaged over the test set:
        # return (1 / test_classif.shape[0]) * jnp.sum(logsumexp(jnp.log(test_classif) + log_w, axis=-1))
        return jnp.sum(
            logsumexp(jnp.log(test_classif) + log_w, axis=-1)
        )
def model(self, test_posteriors, alpha):
test_posteriors = jnp.array(test_posteriors)
n_classes = test_posteriors.shape[1]
# prior in ILR
z = numpyro.sample(
"z",
dist.Normal(jnp.zeros(n_classes-1), 1.0)
)
# back to simplex
prev = self.ilr.inverse(z)
train_prev = jnp.array(self.train_prevalence)
# prior
alpha = jnp.array(alpha)
numpyro.factor(
"dirichlet_prior",
dist.Dirichlet(alpha).log_prob(prev)
)
# Likelihood
numpyro.factor(
"likelihood",
(1.0 / self.temperature) * self.log_likelihood(test_posteriors, test_prev=prev, train_prev=train_prev)
)
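# Example usage (sketch; assumes the QuaPy aggregative API in which `fit` trains both the classifier and
# the aggregation function, and hypothetical arrays X, y, X_test):
#
#   from sklearn.linear_model import LogisticRegression
#   quant = BayesianMAPLS(classifier=LogisticRegression(), prior='uniform')
#   quant.fit(X, y)
#   prev_estimate, conf_region = quant.predict_conf(X_test, confidence_level=0.95)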
# adapted from https://github.com/ChangkunYe/MAPLS/blob/main/label_shift/mapls.py
def mapls(train_probs: np.ndarray,
test_probs: np.ndarray,
pz: np.ndarray,
qy_mode: str = 'soft',
max_iter: int = 100,
init_mode: str = 'identical',
lam: float = None,
dvg_name='kl',
return_lambda=False
):
r"""
Implementation of Maximum A Posteriori Label Shift,
for Unknown target label distribution estimation
Given source domain P(Y_s=i|X_s=x) = f(x) and P(Y_s=i),
estimate targe domain P(Y_t=i) on test set
"""
# Sanity Check
cls_num = len(pz)
assert test_probs.shape[-1] == cls_num
    if not isinstance(max_iter, int) or max_iter <= 0:
        raise ValueError('max_iter should be a positive integer, not ' + str(max_iter))
# Setup d(p,q) measure
if dvg_name == 'kl':
dvg = kl_div
elif dvg_name == 'js':
dvg = js_div
    else:
        raise ValueError('Unsupported distribution distance measure, expected "kl" or "js"')
# Set Prior of Target Label Distribution
q_prior = np.ones(cls_num) / cls_num
# q_prior = pz.copy()
# Lambda estimation-------------------------------------------------------#
if lam is None:
# logging.info('Data shape: %s, %s' % (str(train_probs.shape), str(test_probs.shape)))
# logging.info('Divergence type is %s' % (dvg))
lam = get_lamda(test_probs, pz, q_prior, dvg=dvg, max_iter=max_iter)
# logging.info('Estimated lambda value is %.4f' % lam)
# else:
# logging.info('Assigned lambda is %.4f' % lam)
# EM Algorithm Computation
qz = mapls_EM(test_probs, pz, lam, q_prior, cls_num,
init_mode=init_mode, max_iter=max_iter, qy_mode=qy_mode)
if return_lambda:
return qz, lam
else:
return qz
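# Minimal usage sketch for the standalone estimator (synthetic, hypothetical inputs; note that in this
# adaptation train_probs does not enter the computation):
#
#   rng = np.random.default_rng(0)
#   train_probs = rng.dirichlet(np.ones(3), size=500)  # source-domain posteriors
#   test_probs = rng.dirichlet(np.ones(3), size=200)   # target-domain posteriors
#   pz = np.array([0.5, 0.3, 0.2])                     # source label distribution
#   qz, lam = mapls(train_probs, test_probs, pz, return_lambda=True)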
def mapls_EM(probs, pz, lam, q_prior, cls_num, init_mode='identical', max_iter=100, qy_mode='soft'):
# Normalize Source Label Distribution pz
pz = np.array(pz) / np.sum(pz)
# Initialize Target Label Distribution qz
if init_mode == 'uniform':
qz = np.ones(cls_num) / cls_num
elif init_mode == 'identical':
qz = pz.copy()
else:
raise ValueError('init_mode should be either "uniform" or "identical"')
# Initialize w
w = (np.array(qz) / np.array(pz))
# EM algorithm with MAP estimation----------------------------------------#
for i in range(max_iter):
# print('w shape ', w.shape)
# E-Step--------------------------------------------------------------#
mapls_probs = normalized(probs * w, axis=-1, order=1)
# M-Step--------------------------------------------------------------#
        if qy_mode == 'hard':
            pred = np.argmax(mapls_probs, axis=-1)
            # normalize the counts so that qz_new is a distribution, as in the 'soft' case
            qz_new = np.bincount(pred.reshape(-1), minlength=cls_num) / pred.size
        elif qy_mode == 'soft':
            qz_new = np.mean(mapls_probs, axis=0)
        # elif qy_mode == 'topk':
        #     qz_new = Topk_qy(mapls_probs, cls_num, topk_ratio=0.9, head=0)
        else:
            raise ValueError('qy_mode should be either "soft" or "hard"')
# print(np.shape(pc_probs), np.shape(pred), np.shape(cls_num_list_t))
# Update w with MAP estimation of Target Label Distribution qz
# qz = (qz_new + alpha) / (N + np.sum(alpha))
qz = lam * qz_new + (1 - lam) * q_prior
qz /= qz.sum()
w = qz / pz
return qz
def get_lamda(test_probs, pz, q_prior, dvg, max_iter=50):
K = len(pz)
    # MLLS (EM) estimate of the target label distribution; with lam=1 the prior term vanishes,
    # so the q_prior argument (here 0) is irrelevant
    qz_pred = mapls_EM(test_probs, pz, 1, 0, K, max_iter=max_iter)
TU_div = dvg(qz_pred, q_prior)
TS_div = dvg(qz_pred, pz)
SU_div = dvg(pz, q_prior)
# logging.info('weights are, TU_div %.4f, TS_div %.4f, SU_div %.4f' % (TU_div, TS_div, SU_div))
SU_conf = 1 - lam_forward(SU_div, lam_inv(dpq=0.5, lam=0.2))
TU_conf = lam_forward(TU_div, lam_inv(dpq=0.5, lam=SU_conf))
TS_conf = lam_forward(TS_div, lam_inv(dpq=0.5, lam=SU_conf))
# logging.info('weights are, unviform_weight %.4f, differ_weight %.4f, regularize weight %.4f'
# % (TU_conf, TS_conf, SU_conf))
confs = np.array([TU_conf, 1 - TS_conf])
w = np.array([0.9, 0.1])
lam = np.sum(w * confs)
# logging.info('Estimated lambda is: %.4f', lam)
return lam
def lam_inv(dpq, lam):
    # inverse of lam_forward with respect to gamma: returns the gamma such that lam_forward(dpq, gamma) == lam
    return (1 / (1 - lam) - 1) / dpq
def lam_forward(dpq, gamma):
    # maps a divergence d(p,q) to a weight in [0, 1): gamma*d / (1 + gamma*d)
    return gamma * dpq / (1 + gamma * dpq)
def kl_div(p, q, eps=1e-12):
p = np.asarray(p, dtype=float)
q = np.asarray(q, dtype=float)
mask = p > 0
return np.sum(p[mask] * np.log(p[mask] / (q[mask] + eps)))
def js_div(p, q):
assert (np.abs(np.sum(p) - 1) < 1e-6) and (np.abs(np.sum(q) - 1) < 1e-6)
m = (p + q) / 2
return kl_div(p, m) / 2 + kl_div(q, m) / 2
def normalized(a, axis=-1, order=2):
    r"""
    Normalize `a` along `axis` to unit `order`-norm (order=1 turns non-negative scores into
    row-stochastic probabilities).
    """
    norms = np.atleast_1d(np.linalg.norm(a, order, axis))
    norms[norms == 0] = 1
    return a / np.expand_dims(norms, axis)
def alpha0_from_lamda(lam, n_test, n_classes):
return 1+n_test*(1-lam)/(lam*n_classes)
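# Note on alpha0_from_lamda: under a symmetric Dirichlet(alpha0) prior, the MAP estimate of qz is
# (counts + alpha0 - 1) / (n_test + n_classes*(alpha0 - 1)), i.e., the interpolation
# lam*qz_ML + (1-lam)*uniform with lam = n_test / (n_test + n_classes*(alpha0 - 1)); inverting this
# relation gives the formula above. For instance, n_test=1000, n_classes=10, lam=0.8 yields
# alpha0 = 1 + 1000*0.2/(0.8*10) = 26, and indeed 1000 / (1000 + 10*25) = 0.8.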