log densities for kde

commit 877bfb2b18
parent 839496fb8e
Author: Alejandro Moreo Fernandez
Date:   2026-01-27 17:58:53 +01:00
6 changed files with 163 additions and 40 deletions

View File

@@ -14,6 +14,7 @@ from numbers import Number
 import jax
 import jax.numpy as jnp
+from jax.scipy.special import logsumexp
 import numpyro
 import numpyro.distributions as dist
 from numpyro.infer import MCMC, NUTS
@@ -245,9 +246,16 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
     def _bayesian_numpyro(self, X_probs):
         kdes = self.mix_densities
-        test_densities = np.asarray(
-            [self.pdf(kde_i, X_probs, self.kernel) for kde_i in kdes]
-        )
+        # test_densities = np.asarray(
+        #     [self.pdf(kde_i, X_probs, self.kernel) for kde_i in kdes]
+        # )
+        test_log_densities = np.asarray(
+            [self.pdf(kde_i, X_probs, self.kernel, log_densities=True) for kde_i in kdes]
+        )
+        print(f'min={np.min(test_log_densities)}')
+        print(f'max={np.max(test_log_densities)}')
+        # import sys
+        # sys.exit(0)

         n_classes = X_probs.shape[-1]

         if isinstance(self.prior, str) and self.prior == 'uniform':
@@ -265,7 +273,7 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
         )
         rng_key = jax.random.PRNGKey(self.mcmc_seed)
-        mcmc.run(rng_key, test_densities=test_densities, alpha=alpha)
+        mcmc.run(rng_key, test_log_densities=test_log_densities, alpha=alpha)

         samples_z = mcmc.get_samples()["z"]
@@ -274,11 +282,11 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
         return samples_prev

-    def _numpyro_model(self, test_densities, alpha):
+    def _numpyro_model(self, test_log_densities, alpha):
         """
-        test_densities: shape (n_classes, n_instances,)
+        test_log_densities: shape (n_classes, n_instances,)
         """
-        n_classes = test_densities.shape[0]
+        n_classes = test_log_densities.shape[0]

         # sample in unconstrained R^(n_classes-1)
         z = numpyro.sample(
@@ -288,6 +296,19 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
         prev = self.ilr.inverse(z)  # simplex, shape (n_classes,)

+        # for numerical stability:
+        # eps = 1e-10
+        # prev_safe = jnp.clip(prev, eps, 1.0)
+        # prev_safe = prev_safe / jnp.sum(prev_safe)
+        # prev = prev_safe
+        # from jax import debug
+        # debug.print("prev = {}", prev)
+        # numpyro.factor(
+        #     "check_prev",
+        #     jnp.where(jnp.all(prev > 0), 0.0, -jnp.inf)
+        # )
+
         # prior
         if alpha is not None:
             alpha = jnp.asarray(alpha)
@@ -297,10 +318,20 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
         # if alpha is None, then this corresponds to a weak logistic-normal prior

         # likelihood
-        test_densities = jnp.array(test_densities)
-        likelihoods = jnp.dot(prev, test_densities)
+        # test_densities = jnp.array(test_densities)
+        # likelihoods = jnp.dot(prev, test_densities)
+        # likelihoods = jnp.clip(likelihoods, 1e-12, jnp.inf)  # numerical stability
+        # numpyro.factor(
+        #     "loglik", (1.0 / self.temperature) * jnp.sum(jnp.log(likelihoods))
+        # )
+        log_likelihood = jnp.sum(
+            logsumexp(jnp.log(prev)[:, None] + test_log_densities, axis=0)
+        )
         numpyro.factor(
-            "loglik", (1.0 / self.temperature) * jnp.sum(jnp.log(likelihoods + 1e-10))
+            "loglik", (1.0 / self.temperature) * log_likelihood
         )
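The hunk above is the heart of this commit: the mixture likelihood p(x|π) = Σ_c π_c f_c(x) is now evaluated in log space as logsumexp_c(log π_c + log f_c(x)), so the KDE log-scores are never exponentiated and the old `+ 1e-10` guard becomes unnecessary. The `(1.0 / self.temperature)` factor is unchanged and still tempers the log-likelihood (temperature > 1 flattens the posterior). A minimal sketch in plain NumPy/SciPy (hypothetical values, not part of the commit) showing that the two forms agree where both are representable, and that only the log-space form survives very negative log densities:

import numpy as np
from scipy.special import logsumexp

# hypothetical mixture: 3 classes, 5 test instances
prev = np.array([0.2, 0.5, 0.3])                           # class prevalences, on the simplex
log_dens = np.random.default_rng(0).normal(-5, 2, (3, 5))  # log f_c(x), shape (n_classes, n_instances)

dense = np.log(np.dot(prev, np.exp(log_dens)))                # old form: log sum_c prev_c * f_c(x)
stable = logsumexp(np.log(prev)[:, None] + log_dens, axis=0)  # new form, computed in log space
assert np.allclose(dense, stable)

# with very negative log densities (routine for high-dimensional KDEs),
# the dense form underflows to log(0) = -inf, while the log-space form stays finite
print(np.log(np.dot(prev, np.exp(log_dens - 1000))))                 # -inf everywhere
print(logsumexp(np.log(prev)[:, None] + (log_dens - 1000), axis=0))  # finite values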

View File

@@ -12,37 +12,65 @@ from quapy.data import LabelledCollection
 from quapy.method.aggregative import EMQ

-def fetchMNIST(modality, data_home='./data/mnist_basiccnn'):
+
+def fetchVisual(modality, dataset, net, data_home='./data'):
     MODALITY = ('features', 'predictions', 'logits')
     assert modality in MODALITY, f'unknown modality, valid ones are {MODALITY}'
-    data_home = Path(data_home)
+    file_prefix = f'{dataset}_{net}'
+    data_home = Path(data_home) / file_prefix

     # Load training data
-    train_data = np.load(data_home/'mnist_basiccnn_train_out.npz')
+    train_data = np.load(data_home/f'{file_prefix}_train_out.npz')
     train_X = train_data[modality]
     train_y = train_data['targets']

     # Load validation data
-    val_data = np.load(data_home/'mnist_basiccnn_val_out.npz')
+    val_data = np.load(data_home/f'{file_prefix}_val_out.npz')
     val_X = val_data[modality]
     val_y = val_data['targets']

     # Load test data
-    test_data = np.load(data_home/'mnist_basiccnn_test_out.npz')
+    test_data = np.load(data_home/f'{file_prefix}_test_out.npz')
     test_X = test_data[modality]
     test_y = test_data['targets']

-    print(f'loaded MNIST ({modality=}): '
-          f'#train={len(train_y)}, #val={len(val_y)}, #test={len(test_y)}, #features={train_X.shape[1]}')
-
     train = LabelledCollection(train_X, train_y)
     val = LabelledCollection(val_X, val_y, classes=train.classes_)
     test = LabelledCollection(test_X, test_y, classes=train.classes_)

+    def show_prev_stats(data: LabelledCollection):
+        p = data.prevalence()
+        return f'prevs in [{min(p)*100:.3f}%, {max(p)*100:.3f}%]'
+
+    print(f'loaded {dataset} ({modality=}): '
+          f'#train={len(train)}({show_prev_stats(train)}), '
+          f'#val={len(val)}({show_prev_stats(val)}), '
+          f'#test={len(test)}({show_prev_stats(test)}), '
+          f'#features={train_X.shape[1]}, '
+          f'#classes={len(set(train_y))}')
+
     return train, val, test

+
+def fetchMNIST(modality, data_home='./data'):
+    return fetchVisual(modality, dataset='mnist', net='basiccnn', data_home=data_home)
+
+
+def fetchCIFAR100coarse(modality, data_home='./data'):
+    return fetchVisual(modality, dataset='cifar100coarse', net='resnet18', data_home=data_home)
+
+
+def fetchCIFAR100(modality, data_home='./data'):
+    return fetchVisual(modality, dataset='cifar100', net='resnet18', data_home=data_home)
+
+
+def fetchCIFAR10(modality, data_home='./data'):
+    return fetchVisual(modality, dataset='cifar10', net='resnet18', data_home=data_home)
+
+
+def fetchFashionMNIST(modality, data_home='./data'):
+    return fetchVisual(modality, dataset='fashionmnist', net='basiccnn', data_home=data_home)
+
+
+def fetchSVHN(modality, data_home='./data'):
+    return fetchVisual(modality, dataset='svhn', net='resnet18', data_home=data_home)
+

 class DatasetHandler(ABC):
@@ -87,10 +115,19 @@ class DatasetHandler(ABC):
     def is_binary(self): ...

-class MNISTHandler(DatasetHandler):
-    def __init__(self, n_val_samples=100, n_test_samples=100, sample_size=500, random_state=0):
-        super().__init__(name='MNIST')
+class VisualDataHandler(DatasetHandler):
+    def __init__(self, name, n_val_samples=100, n_test_samples=100, sample_size=500, random_state=0):
+        # mode features : feature-based, the idea is to learn a LogisticRegression on top
+        # mode predictions : posterior probabilities
+        # assert modality in ['features', 'predictions'], f'unknown {modality=}'
+        super().__init__(name=name)
+        modality = 'predictions'
+        if name.endswith('-f'):
+            modality = 'features'
+        elif name.endswith('-l'):
+            modality = 'logits'
+        self.modality = modality
         self.n_val_samples = n_val_samples
         self.n_test_samples = n_test_samples
         self.sample_size = sample_size
@@ -98,35 +135,61 @@ class MNISTHandler(DatasetHandler):
     @lru_cache(maxsize=None)
     def dataset(self):
-        return fetchMNIST(modality='predictions')
+        name = self.name.lower()
+        name = name.replace('-f', '')
+        name = name.replace('-l', '')
+        if name=='mnist':
+            data = fetchMNIST(modality=self.modality)
+        elif name=='cifar100coarse':
+            data = fetchCIFAR100coarse(modality=self.modality)
+        elif name=='cifar100':
+            data = fetchCIFAR100(modality=self.modality)
+        elif name=='cifar10':
+            data = fetchCIFAR10(modality=self.modality)
+        elif name=='fashionmnist':
+            data = fetchFashionMNIST(modality=self.modality)
+        elif name=='svhn':
+            data = fetchSVHN(modality=self.modality)
+        else:
+            raise ValueError(f'unknown dataset {name}')
+
+        # the training set was used to extract features;
+        # we use the validation portion as a training set for quantifiers
+        net_train, val, test = data
+        train, val = val.split_stratified(train_prop=0.6, random_state=self.random_state)
+        return train, val, test

     def get_training(self):
-        return self.dataset()[0]
+        train, val, test = self.dataset()
+        return train

     def get_validation(self):
-        return self.dataset()[1]
+        train, val, test = self.dataset()
+        return val

     def get_train_testprot_for_eval(self):
-        # note that the training goes on the validation split, since the proper training was used for training the neural network
-        _, val, test = self.dataset()
+        train, val, test = self.dataset()
         test_prot = UPP(test, sample_size=self.sample_size, repeats=self.n_test_samples, random_state=self.random_state)
-        return val, test_prot
+        return train+val, test_prot

     def get_train_valprot_for_modsel(self):
-        # the training split is never used (was used to train a neural model)
-        # we consider the validation split as our training data, so we return a new split on it
-        _, val, _ = self.dataset()
-        train, val = val.split_stratified(train_prop=0.6, random_state=self.random_state)
+        train, val, test = self.dataset()
         val_prot = UPP(val, sample_size=self.sample_size, repeats=self.n_val_samples, random_state=self.random_state)
         return train, val_prot

     @classmethod
     def get_datasets(cls):
-        return ['MNIST']
+        datasets = ['cifar10', 'mnist', 'cifar100coarse', 'fashionmnist', 'svhn'] #+ ['cifar100']
+        # datasets_feat = [f'{d}-f' for d in datasets]
+        datasets_feat = [f'{d}-l' for d in datasets]
+        return datasets_feat # + datasets

     @classmethod
     def iter(cls, **kwargs):
-        yield cls(**kwargs)
+        for name in cls.get_datasets():
+            yield cls(name, **kwargs)

     def __repr__(self):
         return f'{self.name}'
@@ -136,6 +199,18 @@ class MNISTHandler(DatasetHandler):
         return False

+class CIFAR100Handler(VisualDataHandler):
+    def __init__(self, name, n_val_samples=100, n_test_samples=100, sample_size=2000, random_state=0):
+        super().__init__(name=name, n_val_samples=n_val_samples, n_test_samples=n_test_samples, sample_size=sample_size, random_state=random_state)
+
+    @classmethod
+    def get_datasets(cls):
+        datasets = ['cifar100']
+        # datasets_feat = [f'{d}-f' for d in datasets]
+        datasets_feat = [f'{d}-l' for d in datasets]
+        return datasets_feat # + datasets
+
+
 # LeQua multiclass tasks
 class LeQuaHandler(DatasetHandler):
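For reference, a hedged usage sketch of the new loaders and handlers (not part of the commit; the module name `datasets` matches the imports used elsewhere in this commit). fetchVisual assumes one directory per {dataset}_{net} pair under ./data, each holding {prefix}_train_out.npz, {prefix}_val_out.npz and {prefix}_test_out.npz with arrays named 'features', 'predictions', 'logits' and 'targets':

from datasets import fetchVisual, VisualDataHandler

# load one dataset directly; expects ./data/cifar10_resnet18/cifar10_resnet18_{train,val,test}_out.npz
train, val, test = fetchVisual(modality='logits', dataset='cifar10', net='resnet18')

# or iterate every configured visual dataset (the '-l' name suffix selects the logit modality)
for handler in VisualDataHandler.iter():
    train = handler.get_training()                        # quantifier training split
    trainval, test_prot = handler.get_train_testprot_for_eval()
    print(handler, len(train), handler.sample_size)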

View File

@@ -0,0 +1,9 @@
+import quapy as qp
+from quapy.error import mae
+import quapy.functional as F
+
+
+for n in range(2, 100, 5):
+    a = F.uniform_prevalence_sampling(n_classes=n, size=10_000)
+    b = F.uniform_prevalence_sampling(n_classes=n, size=10_000)
+    print(f'{n=} ae={mae(a, b):.5f}')
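This new script appears to be a sanity check for the random-guess baseline: a and b are independent uniform draws from the probability simplex, so mae(a, b) approximates the expected absolute error of a uniformly random prevalence prediction against a uniformly random true prevalence. The printed values shrink as n grows, since the coordinates of a uniform draw concentrate around 1/n.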

View File

@@ -7,7 +7,7 @@ from _bayeisan_kdey import BayesianKDEy
 from _bayesian_mapls import BayesianMAPLS
 from commons import experiment_path, KDEyCLR, RESULT_DIR, MockClassifierFromPosteriors
 # import datasets
-from datasets import LeQuaHandler, UCIMulticlassHandler, DatasetHandler, MNISTHandler
+from datasets import LeQuaHandler, UCIMulticlassHandler, DatasetHandler, VisualDataHandler, CIFAR100Handler
 from temperature_calibration import temp_calibration
 from build.lib.quapy.data import LabelledCollection
 from quapy.method.aggregative import DistributionMatchingY as DMy, AggregativeQuantifier, EMQ, CC
@@ -32,10 +32,11 @@ def methods(data_handler: DatasetHandler):
     - bayesian/bootstrap_constructor: is a function that instantiates the bayesian or bootstrap method with the
       quantifier with optimized hyperparameters
     """
-    if isinstance(data_handler, MNISTHandler):
+    if False: # isinstance(data_handler, VisualDataHandler):
         Cls = MockClassifierFromPosteriors
         cls_hyper = {}
         val_split = data_handler.get_validation().Xy  # use this specific collection
+        pass
     else:
         Cls = LogisticRegression
         cls_hyper = {'classifier__C': np.logspace(-4,4,9), 'classifier__class_weight': ['balanced', None]}
@@ -69,9 +70,9 @@ def methods(data_handler: DatasetHandler):
     # --------------------------------------------------------------------------------------------------------
     # yield 'BayesianACC', acc, acc_hyper, lambda hyper: BayesianCC(Cls(), val_split=val_split, mcmc_seed=0), multiclass_method
     # yield 'BayesianHDy', hdy, hdy_hyper, lambda hyper: PQ(Cls(), val_split=val_split, stan_seed=0, **hyper), only_binary
-    yield f'BaKDE-Ait-numpyro', kde_ait, kdey_hyper_clr, lambda hyper: BayesianKDEy(Cls(), kernel='aitchison', mcmc_seed=0, engine='numpyro', val_split=val_split, **hyper), multiclass_method
+    # yield f'BaKDE-Ait-numpyro', kde_ait, kdey_hyper_clr, lambda hyper: BayesianKDEy(Cls(), kernel='aitchison', mcmc_seed=0, engine='numpyro', val_split=val_split, **hyper), multiclass_method
     yield f'BaKDE-Gau-numpyro', kde_gau, kdey_hyper, lambda hyper: BayesianKDEy(Cls(), kernel='gaussian', mcmc_seed=0, engine='numpyro', val_split=val_split, **hyper), multiclass_method
-    yield f'BaKDE-Ait-T*', kde_ait, kdey_hyper_clr, lambda hyper: BayesianKDEy(Cls(),kernel='aitchison', mcmc_seed=0, engine='numpyro', temperature=None, val_split=val_split, **hyper), multiclass_method
+    # yield f'BaKDE-Ait-T*', kde_ait, kdey_hyper_clr, lambda hyper: BayesianKDEy(Cls(),kernel='aitchison', mcmc_seed=0, engine='numpyro', temperature=None, val_split=val_split, **hyper), multiclass_method
     yield f'BaKDE-Gau-T*', kde_gau, kdey_hyper, lambda hyper: BayesianKDEy(Cls(), kernel='gaussian', mcmc_seed=0, engine='numpyro', temperature=None, val_split=val_split, **hyper), multiclass_method
     # yield 'BayEMQ', emq, acc_hyper, lambda hyper: BayesianMAPLS(Cls(), prior='uniform', temperature=1, exact_train_prev=False, val_split=val_split), multiclass_method
     # yield 'BayEMQ*', emq, acc_hyper, lambda hyper: BayesianMAPLS(Cls(), prior='uniform', temperature=None, exact_train_prev=False, val_split=val_split), multiclass_method
@@ -124,6 +125,7 @@ def experiment(dataset: DatasetHandler, point_quantifier: AggregativeQuantifier,
     best_hyperparams = qp.util.pickled_resource(
         hyper_choice_path, model_selection, dataset, cp(point_quantifier), grid
     )
+    print(f'{best_hyperparams=}')

     t_init = time()
     uncertainty_quantifier = uncertainty_quant_constructor(best_hyperparams)
@@ -177,7 +179,7 @@ if __name__ == '__main__':
     result_dir = RESULT_DIR

-    for data_handler in [MNISTHandler]:#, UCIMulticlassHandler,LeQuaHandler]:
+    for data_handler in [CIFAR100Handler, VisualDataHandler]:#, UCIMulticlassHandler,LeQuaHandler]:
         for dataset in data_handler.iter():
             qp.environ['SAMPLE_SIZE'] = dataset.sample_size
             print(f'dataset={dataset.name}')

View File

@@ -8,7 +8,7 @@ from glob import glob
 from pathlib import Path
 import quapy as qp
 from BayesianKDEy.commons import RESULT_DIR
-from BayesianKDEy.datasets import LeQuaHandler, UCIMulticlassHandler, MNISTHandler
+from BayesianKDEy.datasets import LeQuaHandler, UCIMulticlassHandler, VisualDataHandler, CIFAR100Handler
 from error import dist_aitchison
 from quapy.method.confidence import ConfidenceIntervals
 from quapy.method.confidence import ConfidenceEllipseSimplex, ConfidenceEllipseCLR, ConfidenceEllipseILR, ConfidenceIntervals, ConfidenceRegionABC
@@ -119,7 +119,7 @@ n_classes = {}
 tr_size = {}
 tr_prev = {}

-for dataset_handler in [UCIMulticlassHandler, LeQuaHandler, MNISTHandler]:
+for dataset_handler in [UCIMulticlassHandler, LeQuaHandler, VisualDataHandler, CIFAR100Handler]:
     problem_type = 'binary' if dataset_handler.is_binary() else 'multiclass'
     path = f'./{base_dir}/{problem_type}/*.pkl'

View File

@@ -61,7 +61,7 @@ class KDEBase:
         return KernelDensity(bandwidth=bandwidth).fit(X)

-    def pdf(self, kde, X, kernel):
+    def pdf(self, kde, X, kernel, log_densities=False):
         """
         Wraps the density evaluation of scikit-learn's KDE. Scikit-learn returns log-scores (s), so this
         function returns :math:`e^{s}`
@@ -76,7 +76,11 @@ class KDEBase:
         elif kernel == 'ilr':
             X = self.ilr_transform(X)
-        return np.exp(kde.score_samples(X))
+        log_density = kde.score_samples(X)
+        if log_densities:
+            return log_density
+        else:
+            return np.exp(log_density)

     def get_mixture_components(self, X, y, classes, bandwidth, kernel):
         """
@@ -93,6 +97,8 @@ class KDEBase:
         for cat in classes:
             selX = X[y==cat]
             if selX.size==0:
+                print(f'WARNING: empty class {cat}')
+                raise ValueError(f'WARNING: empty class {cat}')
                 selX = [F.uniform_prevalence(len(classes))]
             class_cond_X.append(selX)