import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, HMC
import jax.random as random
from sklearn.base import BaseEstimator
from jax.scipy.special import logsumexp
from BayesianKDEy.commons import ILRtransformation
from quapy.method.aggregative import AggregativeSoftQuantifier
from quapy.method.confidence import WithConfidenceABC, ConfidenceRegionABC
import quapy.functional as F


class BayesianMAPLS(AggregativeSoftQuantifier, WithConfidenceABC):
    """
    Bayesian quantifier that uses MAPLS (Maximum A Posteriori Label Shift) to obtain a point
    estimate of the class prevalence, and MCMC (NUTS) in ILR space to draw samples from the
    posterior distribution of prevalence values, from which confidence regions are constructed.

    :param classifier: the underlying probabilistic classifier (a scikit-learn estimator)
    :param fit_classifier: whether to fit the classifier or assume it has already been fit
    :param val_split: specifies how the validation posteriors, on which the aggregation function
        is fitted, are generated; an integer indicates the number of cross-validation folds (default 5)
    :param exact_train_prev: set to True (default) for using the true training prevalence as the initial
        observation; set to False for computing the training prevalence as an estimate of it, i.e., as the
        expected value of the posterior probabilities of the training instances.
    :param num_warmup: number of MCMC warm-up iterations
    :param num_samples: number of MCMC samples to draw from the posterior
    :param mcmc_seed: random seed for the MCMC sampler
    :param confidence_level: confidence level for the confidence region (default 0.95)
    :param region: type of confidence region to construct (default 'intervals')
    :param temperature: temperature used to rescale the log-likelihood (default 1.)
    :param prior: prior over prevalence values; one of 'uniform', 'map', 'map2', or an explicit
        list of Dirichlet concentration parameters
    :param mapls_chain_init: whether to initialize the MCMC chain at the MAPLS point estimate
    :param verbose: whether to display the MCMC progress bar
    """

    def __init__(self,
                 classifier: BaseEstimator = None,
                 fit_classifier=True,
                 val_split: int = 5,
                 exact_train_prev=True,
                 num_warmup: int = 500,
                 num_samples: int = 1_000,
                 mcmc_seed: int = 0,
                 confidence_level: float = 0.95,
                 region: str = 'intervals',
                 temperature=1.,
                 prior='uniform',
                 mapls_chain_init=True,
                 verbose=False
                 ):

        if num_samples <= 0:
            raise ValueError(f'parameter {num_samples=} must be a positive integer')
        super().__init__(classifier, fit_classifier, val_split)
        self.exact_train_prev = exact_train_prev
        self.num_warmup = num_warmup
        self.num_samples = num_samples
        self.mcmc_seed = mcmc_seed
        self.confidence_level = confidence_level
        self.region = region
        self.temperature = temperature
        self.prior = prior
        self.mapls_chain_init = mapls_chain_init
        self.verbose = verbose

    def aggregation_fit(self, classif_predictions, labels):
        self.train_post = classif_predictions
        if self.exact_train_prev:
            self.train_prevalence = F.prevalence_from_labels(labels, classes=self.classes_)
        else:
            self.train_prevalence = F.prevalence_from_probabilities(classif_predictions)
        self.ilr = ILRtransformation(jax_mode=True)
        return self

    def aggregate(self, classif_predictions: np.ndarray):
        n_test, n_classes = classif_predictions.shape

        # MAPLS point estimate of the test prevalence and the estimated mixing weight lambda
        pi_star, lam = mapls(
            self.train_post,
            test_probs=classif_predictions,
            pz=self.train_prevalence,
            return_lambda=True
        )

        # pi_star: MAPLS point estimate on the simplex, shape (n_classes,); map it to ILR space
        z0 = self.ilr(pi_star)

        # choose the Dirichlet concentration parameters of the prior
        if self.prior == 'uniform':
            alpha = [1.] * n_classes
        elif self.prior == 'map':
            # symmetric prior whose strength is derived from the lambda estimated by MAPLS
            alpha_0 = alpha0_from_lamda(lam, n_test=n_test, n_classes=n_classes)
            alpha = [alpha_0] * n_classes
        elif self.prior == 'map2':
            # re-estimate lambda using the MAPLS solution itself as the prior distribution
            lam2 = get_lamda(
                test_probs=classif_predictions,
                pz=self.train_prevalence,
                q_prior=pi_star,
                dvg=kl_div
            )
            alpha_0 = alpha0_from_lamda(lam2, n_test=n_test, n_classes=n_classes)
            alpha = [alpha_0] * n_classes
        else:
            # assume the concentration parameters were passed explicitly
            alpha = self.prior

        kernel = NUTS(self.model)
        mcmc = MCMC(
            kernel,
            num_warmup=self.num_warmup,
            num_samples=self.num_samples,
            num_chains=1,
            progress_bar=self.verbose
        )

        mcmc.run(
            random.PRNGKey(self.mcmc_seed),
            test_posteriors=classif_predictions,
            alpha=alpha,
            init_params={"z": z0} if self.mapls_chain_init else None
        )

        samples = mcmc.get_samples()["z"]
        self.prevalence_samples = self.ilr.inverse(samples)
        return self.prevalence_samples.mean(axis=0)

    def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
        if confidence_level is None:
            confidence_level = self.confidence_level
        classif_predictions = self.classify(instances)
        point_estimate = self.aggregate(classif_predictions)
        samples = self.prevalence_samples  # available after calling the "aggregate" method
        region = WithConfidenceABC.construct_region(samples, confidence_level=confidence_level, method=self.region)
        return point_estimate, region

    def log_likelihood(self, test_classif, test_prev, train_prev):
        """
        Log-likelihood of the test posteriors under label shift; up to an additive constant,
        log P(x_1,...,x_n | test_prev) = sum_i log sum_j P_train(y=j|x_i) * test_prev_j / train_prev_j
        """
        # n_test = test_classif.shape[0]
        log_w = jnp.log(test_prev) - jnp.log(train_prev)
        # averaged (instead of summed) variant:
        # return (1/n_test) * jnp.sum(
        #     logsumexp(jnp.log(test_classif) + log_w, axis=-1)
        # )
        return jnp.sum(
            logsumexp(jnp.log(test_classif) + log_w, axis=-1)
        )

    def model(self, test_posteriors, alpha):
        test_posteriors = jnp.array(test_posteriors)
        n_classes = test_posteriors.shape[1]

        # sample the prevalence in ILR (isometric log-ratio) space; standard normal prior on z
        z = numpyro.sample(
            "z",
            dist.Normal(jnp.zeros(n_classes - 1), 1.0)
        )

        # back to the simplex
        prev = self.ilr.inverse(z)
        train_prev = jnp.array(self.train_prevalence)

        # Dirichlet prior on the prevalence
        alpha = jnp.array(alpha)
        numpyro.factor(
            "dirichlet_prior",
            dist.Dirichlet(alpha).log_prob(prev)
        )

        # (tempered) label-shift likelihood
        numpyro.factor(
            "likelihood",
            (1.0 / self.temperature) * self.log_likelihood(test_posteriors, test_prev=prev, train_prev=train_prev)
        )
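

# Usage sketch (illustrative only, not part of the original code): BayesianMAPLS follows
# QuaPy's aggregative-quantifier workflow, i.e., fit the quantifier on labelled data and then
# call predict_conf() on test instances to obtain a point prevalence estimate together with a
# confidence region built from the posterior samples. The exact signature of fit() depends on
# the QuaPy version in use, so the following lines are a sketch rather than runnable code:
#
#     from sklearn.linear_model import LogisticRegression
#     quantifier = BayesianMAPLS(classifier=LogisticRegression(), prior='map', region='intervals')
#     quantifier.fit(X_train, y_train)   # assumed fit(X, y)-style API
#     prev_estimate, conf_region = quantifier.predict_conf(X_test)

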
# adapted from https://github.com/ChangkunYe/MAPLS/blob/main/label_shift/mapls.py

def mapls(train_probs: np.ndarray,
          test_probs: np.ndarray,
          pz: np.ndarray,
          qy_mode: str = 'soft',
          max_iter: int = 100,
          init_mode: str = 'identical',
          lam: float = None,
          dvg_name='kl',
          return_lambda=False
          ):
    r"""
    Implementation of Maximum A Posteriori Label Shift (MAPLS) for estimating an unknown
    target label distribution.

    Given the source-domain posteriors P(Y_s=i|X_s=x) = f(x) and priors P(Y_s=i),
    estimate the target-domain label distribution P(Y_t=i) on the test set.
    """
    # Sanity check
    cls_num = len(pz)
    assert test_probs.shape[-1] == cls_num
    if not isinstance(max_iter, int) or max_iter <= 0:
        raise ValueError('max_iter should be a positive integer, not ' + str(max_iter))

    # Set up the d(p,q) divergence measure
    if dvg_name == 'kl':
        dvg = kl_div
    elif dvg_name == 'js':
        dvg = js_div
    else:
        raise ValueError('Unsupported distribution distance measure, expected "kl" or "js".')

    # Set the prior of the target label distribution (uniform)
    q_prior = np.ones(cls_num) / cls_num
    # q_prior = pz.copy()

    # Lambda estimation -------------------------------------------------------#
    if lam is None:
        lam = get_lamda(test_probs, pz, q_prior, dvg=dvg, max_iter=max_iter)

    # EM algorithm computation
    qz = mapls_EM(test_probs, pz, lam, q_prior, cls_num,
                  init_mode=init_mode, max_iter=max_iter, qy_mode=qy_mode)

    if return_lambda:
        return qz, lam
    else:
        return qz


def mapls_EM(probs, pz, lam, q_prior, cls_num, init_mode='identical', max_iter=100, qy_mode='soft'):
    # Normalize source label distribution pz
    pz = np.array(pz) / np.sum(pz)

    # Initialize target label distribution qz
    if init_mode == 'uniform':
        qz = np.ones(cls_num) / cls_num
    elif init_mode == 'identical':
        qz = pz.copy()
    else:
        raise ValueError('init_mode should be either "uniform" or "identical"')

    # Initialize the importance weights w = qz / pz
    w = (np.array(qz) / np.array(pz))

    # EM algorithm with MAP estimation ----------------------------------------#
    for i in range(max_iter):

        # E-Step: re-weight and re-normalize the posteriors under the current qz
        mapls_probs = normalized(probs * w, axis=-1, order=1)

        # M-Step: re-estimate the target label distribution
        if qy_mode == 'hard':
            pred = np.argmax(mapls_probs, axis=-1)
            # normalization added here so that qz_new is a distribution, as in the 'soft' branch
            qz_new = np.bincount(pred.reshape(-1), minlength=cls_num) / pred.size
        elif qy_mode == 'soft':
            qz_new = np.mean(mapls_probs, axis=0)
        # elif qy_mode == 'topk':
        #     qz_new = Topk_qy(mapls_probs, cls_num, topk_ratio=0.9, head=0)
        else:
            raise ValueError('qy_mode should be either "soft" or "hard"')

        # Update w with the MAP estimate of the target label distribution qz
        # (equivalent to qz = (qz_new + alpha) / (N + np.sum(alpha)) for a suitable Dirichlet prior)
        qz = lam * qz_new + (1 - lam) * q_prior
        qz /= qz.sum()
        w = qz / pz

    return qz
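

# Note (added for clarity): with lam=1 the update in mapls_EM reduces to the plain MLLS/EM
# estimator (qz is just the re-estimated target distribution), whereas lam<1 shrinks the
# estimate toward q_prior at every iteration, turning the maximum-likelihood update into a
# MAP update. get_lamda() below relies on this by first running mapls_EM with lam=1 to obtain
# an unregularized MLLS estimate.

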
def get_lamda(test_probs, pz, q_prior, dvg, max_iter=50):
    K = len(pz)

    # MLLS estimate of the target label distribution (mapls_EM with lam=1 is plain MLLS)
    qz_pred = mapls_EM(test_probs, pz, 1, 0, K, max_iter=max_iter)

    # divergences between target (T), source (S) and uniform (U) label distributions
    TU_div = dvg(qz_pred, q_prior)
    TS_div = dvg(qz_pred, pz)
    SU_div = dvg(pz, q_prior)

    # map the divergences to confidence weights in [0, 1)
    SU_conf = 1 - lam_forward(SU_div, lam_inv(dpq=0.5, lam=0.2))
    TU_conf = lam_forward(TU_div, lam_inv(dpq=0.5, lam=SU_conf))
    TS_conf = lam_forward(TS_div, lam_inv(dpq=0.5, lam=SU_conf))

    # combine the confidences into a single lambda
    confs = np.array([TU_conf, 1 - TS_conf])
    w = np.array([0.9, 0.1])
    lam = np.sum(w * confs)

    return lam


def lam_inv(dpq, lam):
    return (1 / (1 - lam) - 1) / dpq


def lam_forward(dpq, gamma):
    return gamma * dpq / (1 + gamma * dpq)
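

# Note (added for clarity): lam_inv and lam_forward are inverses of one another, in the sense
# that lam_forward(dpq, lam_inv(dpq, lam)) == lam, since lam_inv(dpq, lam) = lam / ((1-lam)*dpq).
# get_lamda() uses lam_inv to calibrate the scale gamma so that a divergence of dpq=0.5 maps to a
# reference lambda, and lam_forward to convert observed divergences into weights in [0, 1).

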
def kl_div(p, q, eps=1e-12):
    """KL divergence KL(p || q), ignoring zero entries of p and smoothing q by eps."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)

    mask = p > 0
    return np.sum(p[mask] * np.log(p[mask] / (q[mask] + eps)))


def js_div(p, q):
    """Jensen-Shannon divergence between distributions p and q."""
    assert (np.abs(np.sum(p) - 1) < 1e-6) and (np.abs(np.sum(q) - 1) < 1e-6)
    m = (p + q) / 2
    return kl_div(p, m) / 2 + kl_div(q, m) / 2


def normalized(a, axis=-1, order=2):
    r"""
    Normalize predictions `a` along `axis` using the vector norm of the given `order`
    (order=1 yields probability-style normalization).
    """
    norms = np.atleast_1d(np.linalg.norm(a, order, axis))
    norms[norms == 0] = 1
    return a / np.expand_dims(norms, axis)
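

# Note (added for clarity): alpha0_from_lamda inverts the relation between the mixing weight lam
# used in mapls_EM and a symmetric Dirichlet(alpha_0) prior. The MAP estimate under such a prior
# is qz_j = (n_j + alpha_0 - 1) / (n_test + n_classes*(alpha_0 - 1)), which equals the update
# lam*qz_new_j + (1-lam)/n_classes (with qz_new_j = n_j/n_test and a uniform q_prior) whenever
# lam = n_test / (n_test + n_classes*(alpha_0 - 1)). Solving for alpha_0 gives the formula below.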
def alpha0_from_lamda(lam, n_test, n_classes):
    return 1 + n_test * (1 - lam) / (lam * n_classes)
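

if __name__ == '__main__':
    # Minimal synthetic sanity check (added as an illustrative sketch, not part of the original
    # MAPLS code): simulate a 3-class label-shift problem with soft posteriors and compare the
    # MAPLS estimate of the test prevalence against the true one. All quantities below are made
    # up for illustration purposes only.
    rng = np.random.default_rng(0)
    n_classes = 3
    train_prev = np.array([0.5, 0.3, 0.2])
    test_prev = np.array([0.2, 0.3, 0.5])

    def sample_posteriors(prev, n, concentration=20.):
        # draw labels according to `prev` and generate soft posteriors peaked at the true label
        labels = rng.choice(n_classes, size=n, p=prev)
        alpha = np.ones((n, n_classes))
        alpha[np.arange(n), labels] += concentration
        return np.vstack([rng.dirichlet(a) for a in alpha])

    train_probs = sample_posteriors(train_prev, n=1000)
    test_probs = sample_posteriors(test_prev, n=500)

    qz, lam = mapls(train_probs, test_probs, pz=train_prev, return_lambda=True)
    print(f'estimated lambda: {lam:.4f}')
    print(f'MAPLS estimate of the test prevalence: {np.round(qz, 4)}')
    print(f'true test prevalence:                  {test_prev}')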