import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import jax.random as random
from sklearn.base import BaseEstimator
from jax.scipy.special import logsumexp

from BayesianKDEy.commons import ILRtransformation
from quapy.method.aggregative import AggregativeSoftQuantifier
from quapy.method.confidence import WithConfidenceABC, ConfidenceRegionABC
import quapy.functional as F


class BayesianMAPLS(AggregativeSoftQuantifier, WithConfidenceABC):
    """
    Bayesian variant of MAPLS (Maximum A Posteriori Label Shift). The MAPLS point estimate is used
    to initialize an MCMC (NUTS) sampler that draws posterior samples of the class prevalence
    values in ILR space, from which point estimates and confidence regions are derived.

    :param classifier: the base probabilistic classifier used to generate posterior probabilities
    :param fit_classifier: whether to fit the classifier on the training data (default True)
    :param val_split: number of cross-validation folds (default 5) used to generate the classifier
        predictions for the training instances
    :param exact_train_prev: set to True (default) for using the true training prevalence as the
        initial observation; set to False for computing the training prevalence as an estimate of
        it, i.e., as the expected value of the posterior probabilities of the training instances.
    :param num_warmup: number of MCMC warmup (burn-in) iterations
    :param num_samples: number of MCMC posterior samples to draw
    :param mcmc_seed: random seed for the MCMC sampler
    :param confidence_level: confidence level for the regions returned by predict_conf (default 0.95)
    :param region: method used to construct the confidence region from the posterior samples
        (default 'intervals')
    :param temperature: temperature applied to the log-likelihood; the likelihood term is scaled
        by 1/temperature (default 1.)
    :param prior: 'uniform' for a flat Dirichlet prior; 'map' or 'map2' for symmetric Dirichlet
        priors whose concentration parameter is derived from the MAPLS lambda estimate; otherwise,
        a user-provided vector of Dirichlet concentration parameters
    :param verbose: whether to display the MCMC progress bar (default False)
    """

    def __init__(self,
                 classifier: BaseEstimator = None,
                 fit_classifier=True,
                 val_split: int = 5,
                 exact_train_prev=True,
                 num_warmup: int = 500,
                 num_samples: int = 1_000,
                 mcmc_seed: int = 0,
                 confidence_level: float = 0.95,
                 region: str = 'intervals',
                 temperature=1.,
                 prior='uniform',
                 verbose=False):

        if num_samples <= 0:
            raise ValueError(f'parameter {num_samples=} must be a positive integer')

        super().__init__(classifier, fit_classifier, val_split)
        self.exact_train_prev = exact_train_prev
        self.num_warmup = num_warmup
        self.num_samples = num_samples
        self.mcmc_seed = mcmc_seed
        self.confidence_level = confidence_level
        self.region = region
        self.temperature = temperature
        self.prior = prior
        self.verbose = verbose

    def aggregation_fit(self, classif_predictions, labels):
        self.train_post = classif_predictions
        if self.exact_train_prev:
            self.train_prevalence = F.prevalence_from_labels(labels, classes=self.classes_)
        else:
            self.train_prevalence = F.prevalence_from_probabilities(classif_predictions)
        self.ilr = ILRtransformation(jax_mode=True)
        return self
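    # The MCMC below operates in an unconstrained space: ILRtransformation is assumed to
    # implement the standard isometric log-ratio (ILR) transform, a bijection between the
    # probability simplex over K classes and R^(K-1). The MAPLS point estimate pi_star is
    # mapped to ILR space to initialise the chain, NUTS samples the (K-1)-dimensional
    # variable z, and the samples are mapped back to the simplex via the inverse transform.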
    def aggregate(self, classif_predictions: np.ndarray):
        n_test, n_classes = classif_predictions.shape

        # MAPLS point estimate of the test prevalence and its lambda value
        pi_star, lam = mapls(
            self.train_post,
            test_probs=classif_predictions,
            pz=self.train_prevalence,
            return_lambda=True
        )

        # pi_star: MAP in simplex (shape: [K]), convert to ILR space
        z0 = self.ilr(pi_star)

        if self.prior == 'uniform':
            alpha = [1.] * n_classes
        elif self.prior == 'map':
            alpha_0 = alpha0_from_lamda(lam, n_test=n_test, n_classes=n_classes)
            alpha = [alpha_0] * n_classes
        elif self.prior == 'map2':
            lam2 = get_lamda(
                test_probs=classif_predictions,
                pz=self.train_prevalence,
                q_prior=pi_star,
                dvg=kl_div
            )
            alpha_0 = alpha0_from_lamda(lam2, n_test=n_test, n_classes=n_classes)
            alpha = [alpha_0] * n_classes
        else:
            alpha = self.prior

        kernel = NUTS(self.model)
        mcmc = MCMC(
            kernel,
            num_warmup=self.num_warmup,
            num_samples=self.num_samples,
            num_chains=1,
            progress_bar=self.verbose
        )
        mcmc.run(
            random.PRNGKey(self.mcmc_seed),
            test_posteriors=classif_predictions,
            alpha=alpha,
            init_params={"z": z0}
        )

        samples = mcmc.get_samples()["z"]
        self.prevalence_samples = self.ilr.inverse(samples)

        return self.prevalence_samples.mean(axis=0)

    def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
        if confidence_level is None:
            confidence_level = self.confidence_level
        classif_predictions = self.classify(instances)
        point_estimate = self.aggregate(classif_predictions)
        samples = self.prevalence_samples  # available after calling "aggregate"
        region = WithConfidenceABC.construct_region(samples, confidence_level=confidence_level, method=self.region)
        return point_estimate, region

    def log_likelihood(self, test_classif, test_prev, train_prev):
        # n_test = test_classif.shape[0]
        log_w = jnp.log(test_prev) - jnp.log(train_prev)
        # return (1/n_test) * jnp.sum(
        #     logsumexp(jnp.log(test_classif) + log_w, axis=-1)
        # )
        return jnp.sum(
            logsumexp(jnp.log(test_classif) + log_w, axis=-1)
        )

    def model(self, test_posteriors, alpha):
        test_posteriors = jnp.array(test_posteriors)
        n_classes = test_posteriors.shape[1]

        # prior in ILR space
        z = numpyro.sample(
            "z",
            dist.Normal(jnp.zeros(n_classes - 1), 1.0)
        )

        # back to the simplex
        prev = self.ilr.inverse(z)
        train_prev = jnp.array(self.train_prevalence)

        # Dirichlet prior over prevalence values
        alpha = jnp.array(alpha)
        numpyro.factor(
            "dirichlet_prior",
            dist.Dirichlet(alpha).log_prob(prev)
        )

        # (tempered) likelihood
        numpyro.factor(
            "likelihood",
            (1.0 / self.temperature) * self.log_likelihood(test_posteriors, test_prev=prev, train_prev=train_prev)
        )


# adapted from https://github.com/ChangkunYe/MAPLS/blob/main/label_shift/mapls.py
def mapls(train_probs: np.ndarray,
          test_probs: np.ndarray,
          pz: np.ndarray,
          qy_mode: str = 'soft',
          max_iter: int = 100,
          init_mode: str = 'identical',
          lam: float = None,
          dvg_name='kl',
          return_lambda=False):
    r"""
    Implementation of Maximum A Posteriori Label Shift (MAPLS) for unknown target label
    distribution estimation. Given the source-domain posteriors P(Y_s=i|X_s=x) = f(x) and the
    source label distribution P(Y_s=i), estimate the target-domain label distribution P(Y_t=i)
    on the test set.
    """
    # Sanity check
    cls_num = len(pz)
    assert test_probs.shape[-1] == cls_num
    if type(max_iter) != int or max_iter < 0:
        raise Exception('max_iter should be a positive integer, not ' + str(max_iter))

    # Setup d(p,q) divergence measure
    if dvg_name == 'kl':
        dvg = kl_div
    elif dvg_name == 'js':
        dvg = js_div
    else:
        raise Exception('Unsupported distribution distance measure, expect kl or js.')

    # Set prior of the target label distribution
    q_prior = np.ones(cls_num) / cls_num
    # q_prior = pz.copy()

    # Lambda estimation-------------------------------------------------------#
    if lam is None:
        # logging.info('Data shape: %s, %s' % (str(train_probs.shape), str(test_probs.shape)))
        # logging.info('Divergence type is %s' % (dvg))
        lam = get_lamda(test_probs, pz, q_prior, dvg=dvg, max_iter=max_iter)
        # logging.info('Estimated lambda value is %.4f' % lam)
    # else:
    #     logging.info('Assigned lambda is %.4f' % lam)

    # EM algorithm computation
    qz = mapls_EM(test_probs, pz, lam, q_prior, cls_num,
                  init_mode=init_mode, max_iter=max_iter, qy_mode=qy_mode)

    if return_lambda:
        return qz, lam
    else:
        return qz
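def _demo_mapls():
    # Toy, self-contained illustration of the mapls call signature (not part of the upstream
    # MAPLS code). The Dirichlet-sampled posteriors and the prevalence values are made up for
    # demonstration only; the returned estimate lies between pz and the uniform prior, since
    # the lambda-weighted regularization pulls the EM solution toward uniform.
    rng = np.random.default_rng(0)
    pz = np.array([0.5, 0.3, 0.2])                        # source label distribution
    test_probs = rng.dirichlet(alpha=20 * pz, size=500)   # synthetic, roughly unshifted posteriors
    qz, lam = mapls(train_probs=test_probs, test_probs=test_probs, pz=pz, return_lambda=True)
    return qz, lam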
def mapls_EM(probs, pz, lam, q_prior, cls_num, init_mode='identical', max_iter=100, qy_mode='soft'):
    # Normalize source label distribution pz
    pz = np.array(pz) / np.sum(pz)

    # Initialize target label distribution qz
    if init_mode == 'uniform':
        qz = np.ones(cls_num) / cls_num
    elif init_mode == 'identical':
        qz = pz.copy()
    else:
        raise ValueError('init_mode should be either "uniform" or "identical"')

    # Initialize importance weights w
    w = (np.array(qz) / np.array(pz))

    # EM algorithm with MAP estimation----------------------------------------#
    for i in range(max_iter):
        # print('w shape ', w.shape)
        # E-Step--------------------------------------------------------------#
        mapls_probs = normalized(probs * w, axis=-1, order=1)

        # M-Step--------------------------------------------------------------#
        if qy_mode == 'hard':
            pred = np.argmax(mapls_probs, axis=-1)
            qz_new = np.bincount(pred.reshape(-1), minlength=cls_num)
        elif qy_mode == 'soft':
            qz_new = np.mean(mapls_probs, axis=0)
        # elif qy_mode == 'topk':
        #     qz_new = Topk_qy(mapls_probs, cls_num, topk_ratio=0.9, head=0)
        else:
            raise Exception('mapls mode should be either "soft" or "hard".')
        # print(np.shape(pc_probs), np.shape(pred), np.shape(cls_num_list_t))

        # Update w with MAP estimation of the target label distribution qz
        # qz = (qz_new + alpha) / (N + np.sum(alpha))
        qz = lam * qz_new + (1 - lam) * q_prior
        qz /= qz.sum()
        w = qz / pz

    return qz


def get_lamda(test_probs, pz, q_prior, dvg, max_iter=50):
    K = len(pz)

    # MLLS estimation of the target-domain label distribution (lam=1, no prior regularization)
    qz_pred = mapls_EM(test_probs, pz, 1, 0, K, max_iter=max_iter)

    TU_div = dvg(qz_pred, q_prior)
    TS_div = dvg(qz_pred, pz)
    SU_div = dvg(pz, q_prior)
    # logging.info('weights are, TU_div %.4f, TS_div %.4f, SU_div %.4f' % (TU_div, TS_div, SU_div))

    SU_conf = 1 - lam_forward(SU_div, lam_inv(dpq=0.5, lam=0.2))
    TU_conf = lam_forward(TU_div, lam_inv(dpq=0.5, lam=SU_conf))
    TS_conf = lam_forward(TS_div, lam_inv(dpq=0.5, lam=SU_conf))
    # logging.info('weights are, uniform_weight %.4f, differ_weight %.4f, regularize weight %.4f'
    #              % (TU_conf, TS_conf, SU_conf))

    confs = np.array([TU_conf, 1 - TS_conf])
    w = np.array([0.9, 0.1])
    lam = np.sum(w * confs)
    # logging.info('Estimated lambda is: %.4f', lam)

    return lam


def lam_inv(dpq, lam):
    return (1 / (1 - lam) - 1) / dpq


def lam_forward(dpq, gamma):
    return gamma * dpq / (1 + gamma * dpq)


# def kl_div(p, q):
#     p = np.asarray(p, dtype=np.float32)
#     q = np.asarray(q + 1e-8, dtype=np.float32)
#
#     return np.sum(np.where(p != 0, p * np.log(p / q), 0))

def kl_div(p, q, eps=1e-12):
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p > 0
    return np.sum(p[mask] * np.log(p[mask] / (q[mask] + eps)))


def js_div(p, q):
    assert (np.abs(np.sum(p) - 1) < 1e-6) and (np.abs(np.sum(q) - 1) < 1e-6)
    m = (p + q) / 2
    return kl_div(p, m) / 2 + kl_div(q, m) / 2


def normalized(a, axis=-1, order=2):
    r"""
    Prediction normalization: rescales each row of `a` by its L-`order` norm along `axis`.
    """
    l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
    l2[l2 == 0] = 1
    return a / np.expand_dims(l2, axis)


def alpha0_from_lamda(lam, n_test, n_classes):
    # Concentration of a symmetric Dirichlet prior equivalent to mixing the empirical estimate
    # with the uniform prior with weight lam, given n_test observations:
    #   lam = n_test / (n_test + n_classes*(alpha_0 - 1))  <=>  alpha_0 = 1 + n_test*(1-lam)/(lam*n_classes)
    return 1 + n_test * (1 - lam) / (lam * n_classes)
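# ---------------------------------------------------------------------------#
# Minimal usage sketch (illustrative only). The synthetic data and the
# LogisticRegression classifier are placeholders, and the sklearn-style
# fit(X, y) call is an assumption based on the aggregation_fit signature
# above; adapt the fitting call if the installed quapy version expects a
# LabelledCollection instead.
# ---------------------------------------------------------------------------#
if __name__ == '__main__':
    from sklearn.linear_model import LogisticRegression

    rng = np.random.default_rng(0)
    y = rng.integers(0, 3, size=2000)                 # 3 synthetic classes
    X = rng.normal(size=(2000, 10)) + y[:, None]      # weakly informative features
    X_test = rng.normal(size=(500, 10)) + rng.integers(0, 3, size=(500, 1))

    quantifier = BayesianMAPLS(
        classifier=LogisticRegression(max_iter=1000),
        num_warmup=200,
        num_samples=500,
        confidence_level=0.95,
        region='intervals',
    )
    quantifier.fit(X, y)                              # assumed (X, y) signature
    point_estimate, conf_region = quantifier.predict_conf(X_test)
    print('point estimate:', point_estimate)
    print('confidence region:', conf_region)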