Adding NUTS to KDEy

This commit is contained in:
parent 89a8cad0b3
commit e33b291357

@@ -1,3 +1,5 @@
+from functools import lru_cache
+from numpy.ma.core import shape
 from sklearn.base import BaseEstimator
 import numpy as np

@@ -5,13 +7,19 @@ import numpy as np
 import quapy.util
 from quapy.method._kdey import KDEBase
 from quapy.method.confidence import WithConfidenceABC, ConfidenceRegionABC
-from quapy.functional import CLRtransformation
+from quapy.functional import CLRtransformation, ILRtransformation
 from quapy.method.aggregative import AggregativeSoftQuantifier
 from tqdm import tqdm
 import quapy.functional as F
-#import emcee
+import emcee
+
+import jax
+import jax.numpy as jnp
+import numpyro
+import numpyro.distributions as dist
+from numpyro.infer import MCMC, NUTS


 class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
     """

@@ -49,6 +57,7 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
                  explore='simplex',
                  step_size=0.05,
                  temperature=1.,
+                 engine='numpyro',
                  verbose: bool = False):

        if num_warmup <= 0:

@@ -58,6 +67,7 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
         assert explore in ['simplex', 'clr', 'ilr'], \
             f'unexpected value for param {explore=}; valid ones are "simplex", "clr", and "ilr"'
         assert temperature>0., f'temperature must be >0'
+        assert engine in ['rw-mh', 'emcee', 'numpyro']

         super().__init__(classifier, fit_classifier, val_split)
         self.bandwidth = KDEBase._check_bandwidth(bandwidth, kernel)

@@ -70,6 +80,7 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
         self.explore = explore
         self.step_size = step_size
         self.temperature = temperature
+        self.engine = engine
         self.verbose = verbose

     def aggregation_fit(self, classif_predictions, labels):

@@ -77,8 +88,12 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
         return self

     def aggregate(self, classif_predictions):
-        # self.prevalence_samples = self._bayesian_kde(classif_predictions, init=None, verbose=self.verbose)
-        self.prevalence_samples = self._bayesian_emcee(classif_predictions)
+        if self.engine == 'rw-mh':
+            self.prevalence_samples = self._bayesian_kde(classif_predictions, init=None, verbose=self.verbose)
+        elif self.engine == 'emcee':
+            self.prevalence_samples = self._bayesian_emcee(classif_predictions)
+        elif self.engine == 'numpyro':
+            self.prevalence_samples = self._bayesian_numpyro(classif_predictions)
         return self.prevalence_samples.mean(axis=0)

     def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):

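A minimal usage sketch of the new engine switch; the dataset, learner, and hyperparameters below are illustrative assumptions (not from this commit), and the fit/classify calls follow the usual QuaPy aggregative API:

import quapy as qp
from sklearn.linear_model import LogisticRegression

# hypothetical end-to-end run; 'academic-success' is just one UCI multiclass dataset
train, test = qp.datasets.fetch_UCIMulticlassDataset('academic-success').train_test
quantifier = BayesianKDEy(LogisticRegression(), engine='numpyro',
                          num_warmup=500, num_samples=1000, mcmc_seed=0)
quantifier.fit(*train.Xy)
posteriors = quantifier.classify(test.X)           # per-instance posterior probabilities
point_estimate = quantifier.aggregate(posteriors)  # mean of the posterior prevalence samples
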
@@ -187,6 +202,8 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
     def _bayesian_emcee(self, X_probs):
         ndim = X_probs.shape[1]
         nwalkers = 32
+        if nwalkers < (ndim*2):
+            nwalkers = ndim * 2 + 1

         f = CLRtransformation()

@@ -198,6 +215,7 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):

         kdes = self.mix_densities
         test_densities = np.asarray([self.pdf(kde_i, X_probs, self.kernel) for kde_i in kdes])
         # test_densities_unconstrained = [f(t) for t in test_densities]

         # p0 = np.random.normal(nwalkers, ndim)
+        p0 = F.uniform_prevalence_sampling(ndim, nwalkers)

@@ -211,5 +229,114 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
         samples = f.inverse(samples)
         return samples

+    def _bayesian_numpyro(self, X_probs):
+        kdes = self.mix_densities
+        test_densities = np.asarray(
+            [self.pdf(kde_i, X_probs, self.kernel) for kde_i in kdes]
+        )
+
+        # move to jax
+        test_densities = jnp.array(test_densities)
+
+        kernel = NUTS(self._numpyro_model)
+        mcmc = MCMC(
+            kernel,
+            num_warmup=self.num_warmup,
+            num_samples=self.num_samples,
+            num_chains=1,
+            progress_bar=self.verbose,
+        )
+
+        rng_key = jax.random.PRNGKey(self.mcmc_seed)
+        mcmc.run(rng_key, test_densities)
+
+        samples_z = mcmc.get_samples()["z"]
+
+        # back to simplex
+        ilr = ILRtransformation(jax_mode=True)
+        samples_prev = np.asarray(ilr.inverse(np.asarray(samples_z)))
+
+        return samples_prev
+
+    def _numpyro_model(self, test_densities):
+        """
+        test_densities: shape (C, N)
+        """
+        C = test_densities.shape[0]
+        ilr = ILRtransformation(jax_mode=True)
+
+        # sample in unconstrained R^{C-1}
+        z = numpyro.sample(
+            "z",
+            dist.Normal(0.0, 1.0).expand([C - 1])
+        )
+
+        prev = ilr.inverse(z)  # simplex, shape (C,)
+
+        # likelihood
+        likelihoods = jnp.dot(prev, test_densities)
+        numpyro.factor(
+            "loglik",
+            (1.0 / self.temperature) * jnp.sum(jnp.log(likelihoods + 1e-10))
+        )
+
+
+def in_simplex(x):
+    return np.all(x >= 0) and np.isclose(x.sum(), 1)
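# Aside (not part of this commit): a self-contained sketch of the same NUTS
# pattern used by _numpyro_model above; the densities are made up, and a
# softmax over z with an appended zero stands in for the ILR inverse.
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

densities = jnp.array([[0.9, 0.2, 0.1],
                       [0.1, 0.7, 0.3],
                       [0.0, 0.1, 0.6]])   # shape (C=3 classes, N=3 test points)

def model(densities, temperature=1.0):
    C = densities.shape[0]
    # unconstrained coordinates in R^{C-1}
    z = numpyro.sample("z", dist.Normal(0.0, 1.0).expand([C - 1]))
    prev = jax.nn.softmax(jnp.append(z, 0.0))   # a point on the simplex
    likelihoods = jnp.dot(prev, densities)      # shape (N,)
    numpyro.factor("loglik", (1.0 / temperature) * jnp.sum(jnp.log(likelihoods + 1e-10)))

mcmc = MCMC(NUTS(model), num_warmup=200, num_samples=200, progress_bar=False)
mcmc.run(jax.random.PRNGKey(0), densities)
z_samples = mcmc.get_samples()["z"]             # shape (200, C-1)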
+class ILRtransformation(F.CompositionalTransformation):
+    def __init__(self, jax_mode=False):
+        self.jax_mode = jax_mode
+
+    def array(self, X):
+        if self.jax_mode:
+            return jnp.array(X)
+        else:
+            return np.asarray(X)
+
+    def __call__(self, X):
+        X = self.array(X)
+        X = quapy.error.smooth(X, self.EPSILON)
+        k = X.shape[-1]
+        V = self.array(self.get_V(k))
+        logp = jnp.log(X) if self.jax_mode else np.log(X)
+        return logp @ V.T
+
+    def inverse(self, Z):
+        Z = self.array(Z)
+        k_minus_1 = Z.shape[-1]
+        k = k_minus_1 + 1
+        V = self.array(self.get_V(k))
+        logp = Z @ V
+        p = jnp.exp(logp) if self.jax_mode else np.exp(logp)
+        p = p / jnp.sum(p, axis=-1, keepdims=True) if self.jax_mode else p / np.sum(p, axis=-1, keepdims=True)
+        return p
+
+    @lru_cache(maxsize=None)
+    def get_V(self, k):
+        def helmert_matrix(k):
+            """
+            Returns the (k x k) Helmert matrix.
+            """
+            H = np.zeros((k, k))
+            for i in range(1, k):
+                H[i, :i] = 1
+                H[i, i] = -(i)
+                H[i] = H[i] / np.sqrt(i * (i + 1))
+            # row 0 stays zeros; will be discarded
+            return H
+
+        def ilr_basis(k):
+            """
+            Constructs an orthonormal ILR basis using the Helmert submatrix.
+            Output shape: (k-1, k)
+            """
+            H = helmert_matrix(k)
+            V = H[1:, :]  # remove first row of zeros
+            return V
+
+        return ilr_basis(k)

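A quick sanity check of the transformation above (a sketch; the tolerance allows for the smoothing epsilon applied in __call__):

import numpy as np

ilr = ILRtransformation(jax_mode=False)
p = np.array([0.2, 0.3, 0.5])
z = ilr(p)                                   # unconstrained coordinates, shape (2,)
assert np.allclose(ilr.inverse(z), p, atol=1e-3)

V = np.asarray(ilr.get_V(3))                 # shape (2, 3)
assert np.allclose(V @ V.T, np.eye(2))       # rows are orthonormal
assert np.allclose(V.sum(axis=1), 0)         # rows are orthogonal to the 1-vector
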
@@ -64,9 +64,9 @@ def methods():
    only_multiclass = 'only_multiclass'

    # yield 'BootstrapACC', ACC(LR()), acc_hyper, lambda hyper: AggregativeBootstrap(ACC(LR()), n_test_samples=1000, random_state=0), multiclass_method
-    # yield 'BayesianACC', ACC(LR()), acc_hyper, lambda hyper: BayesianCC(LR(), mcmc_seed=0), multiclass_method
+    yield 'BayesianACC', ACC(LR()), acc_hyper, lambda hyper: BayesianCC(LR(), mcmc_seed=0), multiclass_method

-    yield 'BootstrapEMQ', EMQ(LR(), on_calib_error='backup', val_split=5), emq_hyper, lambda hyper: AggregativeBootstrap(EMQ(LR(), on_calib_error='backup', calib=hyper['calib'], val_split=5), n_test_samples=1000, random_state=0), multiclass_method
+    # yield 'BootstrapEMQ', EMQ(LR(), on_calib_error='backup', val_split=5), emq_hyper, lambda hyper: AggregativeBootstrap(EMQ(LR(), on_calib_error='backup', calib=hyper['calib'], val_split=5), n_test_samples=1000, random_state=0), multiclass_method

    # yield 'BootstrapHDy', DMy(LR()), hdy_hyper, lambda hyper: AggregativeBootstrap(DMy(LR(), **hyper), n_test_samples=1000, random_state=0), multiclass_method
    # yield 'BayesianHDy', DMy(LR()), hdy_hyper, lambda hyper: PQ(LR(), stan_seed=0, **hyper), only_binary

@@ -78,6 +78,8 @@ def methods():
    # yield 'BayKDEy*CLR2', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, explore='clr', step_size=.05, **hyper), multiclass_method
    # yield 'BayKDEy*ILR', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, explore='ilr', step_size=.15, **hyper), only_multiclass
    # yield 'BayKDEy*ILR2', KDEyILR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='ilr', mcmc_seed=0, explore='ilr', step_size=.1, **hyper), only_multiclass
+    yield f'BaKDE-emcee', KDEyML(LR()), kdey_hyper, lambda hyper: BayesianKDEy(mcmc_seed=0, num_warmup=100, num_samples=100, step_size=.1, engine='emcee', **hyper), multiclass_method
+    yield f'BaKDE-numpyro', KDEyML(LR()), kdey_hyper, lambda hyper: BayesianKDEy(mcmc_seed=0, step_size=.1, engine='numpyro', **hyper), multiclass_method


def model_selection(train: LabelledCollection, point_quantifier: AggregativeQuantifier, grid: dict):

@@ -153,18 +155,18 @@ if __name__ == '__main__':
    binary = {
        'datasets': qp.datasets.UCI_BINARY_DATASETS,
        'fetch_fn': qp.datasets.fetch_UCIBinaryDataset,
-        'sample_size': 500
+        'sample_size': 100  # previous: 500
    }

    multiclass = {
        'datasets': qp.datasets.UCI_MULTICLASS_DATASETS,
        'fetch_fn': qp.datasets.fetch_UCIMulticlassDataset,
-        'sample_size': 1000
+        'sample_size': 200  # previous: 1000
    }

    result_dir = Path('./results')

-    for setup in [binary, multiclass]:  # [binary, multiclass]:
+    for setup in [multiclass]:  # [binary, multiclass]:
        qp.environ['SAMPLE_SIZE'] = setup['sample_size']
        for data_name in setup['datasets']:
            print(f'dataset={data_name}')

@@ -66,16 +66,21 @@ def update_pickle_with_region(report, file, conf_name, conf_region_class, **kwargs):
    if f'coverage-{conf_name}' not in report:
        covs, amps, winkler = compute_coverage_amplitude(conf_region_class, **kwargs)

+        # amperr (lower is better) counts the amplitude when the true value was covered, or 1 (max amplitude) otherwise
+        amperrs = [amp if cov == 1.0 else 1. for amp, cov in zip(amps, covs)]
+
        update_fields = {
            f'coverage-{conf_name}': covs,
            f'amplitude-{conf_name}': amps,
-            f'winkler-{conf_name}': winkler
+            f'winkler-{conf_name}': winkler,
+            f'amperr-{conf_name}': amperrs,
        }

        update_pickle(report, file, update_fields)

-methods = None  # show all methods
-# methods = ['BayesianACC', 'BayesianKDEy']
+# methods = None  # show all methods
+methods = ['BayesianACC', 'BayesianKDEy', 'BaKDE-emcee', 'BaKDE-numpyro']

for setup in ['multiclass']:
    path = f'./results/{setup}/*.pkl'

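A tiny worked example of the amperr aggregation defined above (the numbers are made up):

covs = [1.0, 0.0, 1.0]     # did the region cover the true prevalence?
amps = [0.10, 0.30, 0.25]  # amplitude of each confidence region
amperrs = [amp if cov == 1.0 else 1. for amp, cov in zip(amps, covs)]
# -> [0.10, 1.0, 0.25]: a miss is charged the maximal amplitude of 1
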
@@ -103,22 +108,26 @@ for setup in ['multiclass']:
        table['c-CI'].extend(report['coverage-CI'])
        table['a-CI'].extend(report['amplitude-CI'])
        table['w-CI'].extend(report['winkler-CI'])
+        table['amperr-CI'].extend(report['amperr-CI'])

        table['c-CE'].extend(report['coverage-CE'])
        table['a-CE'].extend(report['amplitude-CE'])
+        table['amperr-CE'].extend(report['amperr-CE'])

        table['c-CLR'].extend(report['coverage-CLR'])
        table['a-CLR'].extend(report['amplitude-CLR'])
+        table['amperr-CLR'].extend(report['amperr-CLR'])

        table['c-ILR'].extend(report['coverage-ILR'])
        table['a-ILR'].extend(report['amplitude-ILR'])
+        table['amperr-ILR'].extend(report['amperr-ILR'])

        table['aitch'].extend(qp.error.dist_aitchison(results['true-prevs'], results['point-estim']))
        # table['aitch-well'].extend(qp.error.dist_aitchison(results['true-prevs'], [ConfidenceEllipseILR(samples).mean_ for samples in results['samples']]))
        # table['aitch'].extend()
-        table['reg-score-ILR'].extend(
-            [region_score(true_prev, ConfidenceEllipseILR(samples)) for true_prev, samples in zip(results['true-prevs'], results['samples'])]
-        )
+        # table['reg-score-ILR'].extend(
+        #     [region_score(true_prev, ConfidenceEllipseILR(samples)) for true_prev, samples in zip(results['true-prevs'], results['samples'])]
+        # )

@@ -145,7 +154,7 @@ for setup in ['multiclass']:
    # if n < min_train:
    #     df = df[df["dataset"] != data_name]

-    for region in ['ILR']:  # , 'CI', 'CE', 'CLR', 'ILR']:
+    for region in ['CI', 'CE', 'CLR', 'ILR']:
        if setup == 'binary' and region=='ILR':
            continue
        # pv = pd.pivot_table(

@@ -153,12 +162,15 @@ for setup in ['multiclass']:
        # )
        pv = pd.pivot_table(
            df, index='dataset', columns='method', values=[
+                f'amperr-{region}',
                f'a-{region}',
                f'c-{region}',
                #f'w-{region}',
-                # 'ae',
+                'ae',
                # 'rae',
                # f'aitch',
                # f'aitch-well'
-                'reg-score-ILR',
+                # 'reg-score-ILR',
            ], margins=True
        )
        pv['n_classes'] = pv.index.map(n_classes).astype('Int64')

@@ -68,8 +68,8 @@ if __name__ == '__main__':

    setup = multiclass
    # data_name = 'isolet'
    # data_name = 'cmc'
-    data_name = 'abalone'
+    data_name = 'academic-success'
+    # data_name = 'abalone'

    qp.environ['SAMPLE_SIZE'] = setup['sample_size']
    print(f'dataset={data_name}')

@@ -1,13 +1,14 @@
 Change Log 0.2.1
 -----------------

 - Added DirichletProtocol, which allows generating samples according to a parameterized Dirichlet prior.
 - Added squared ratio error.
 - Improved efficiency of confidence-region coverage functions.
 - Added Precise Quantifier to WithConfidence methods (a Bayesian adaptation of HDy).
 - Improved documentation of confidence regions.
 - Added ReadMe method by Daniel Hopkins and Gary King.
 - Internal index in LabelledCollection is now "lazy", and is only constructed if required.
-- I have added dist_aitchison and mean_dist_aitchison as a new error evaluation metric
+- Added dist_aitchison and mean_dist_aitchison as a new error evaluation metric.

 Change Log 0.2.0
 -----------------

@@ -758,7 +758,7 @@ def fetch_UCIMulticlassLabelledCollection(dataset_name, data_home=None, min_clas
     # restrict classes to only those with at least min_ipc instances
     classes = classes[data.counts() >= min_ipc]
     # filter X and y keeping only datapoints belonging to valid classes
-    filter_idx = np.in1d(data.y, classes)
+    filter_idx = np.isin(data.y, classes)
     X, y = data.X[filter_idx], data.y[filter_idx]
     # map classes to range(len(classes))
     y = np.searchsorted(classes, y)

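This swaps out np.in1d, which is deprecated in recent NumPy releases, for its drop-in replacement np.isin; for example:

import numpy as np

y = np.array([0, 1, 2, 3, 2])
classes = np.array([1, 2])
np.isin(y, classes)   # array([False,  True,  True, False,  True])
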
@@ -33,7 +33,7 @@ P_TEST_C: str = "P_test(C)"
 P_C_COND_Y: str = "P(C|Y)"


-def model(n_c_unlabeled: np.ndarray, n_y_and_c_labeled: np.ndarray) -> None:
+def model_bayesianCC(n_c_unlabeled: np.ndarray, n_y_and_c_labeled: np.ndarray) -> None:
     """
     Defines a probabilistic model in `NumPyro <https://num.pyro.ai/>`_.

@@ -57,6 +57,7 @@ def model(n_c_unlabeled: np.ndarray, n_y_and_c_labeled: np.ndarray) -> None:
     numpyro.sample('N_c', dist.Multinomial(jnp.sum(n_c_unlabeled), p_c), obs=n_c_unlabeled)


+
 def sample_posterior(
     n_c_unlabeled: np.ndarray,
     n_y_and_c_labeled: np.ndarray,

@@ -78,7 +79,7 @@ def sample_posterior(
     :return: a `dict` with the samples. The keys are the names of the latent variables.
     """
     mcmc = numpyro.infer.MCMC(
-        numpyro.infer.NUTS(model),
+        numpyro.infer.NUTS(model_bayesianCC),
         num_warmup=num_warmup,
         num_samples=num_samples,
         progress_bar=False

@@ -171,7 +171,7 @@ class AbstractStochasticSeededProtocol(AbstractProtocol):
         return sample


-class OnLabelledCollectionProtocol:
+class OnLabelledCollectionProtocol(AbstractStochasticSeededProtocol):
     """
     Protocols that generate samples from a :class:`qp.data.LabelledCollection` object.
     """

@@ -229,8 +229,17 @@ class OnLabelledCollectionProtocol:
         elif return_type=='index':
             return lambda lc,params:params

+    def sample(self, index):
+        """
+        Realizes the sample given the index of the instances.
+
+        :param index: indexes of the instances to select
+        :return: an instance of :class:`qp.data.LabelledCollection`
+        """
+        return self.data.sampling_from_index(index)
+

-class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
+class APP(OnLabelledCollectionProtocol):
     """
     Implementation of the artificial prevalence protocol (APP).
     The APP consists of exploring a grid of prevalence values containing `n_prevalences` points (e.g.,

@@ -311,15 +320,6 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
         indexes.append(index)
         return indexes

-    def sample(self, index):
-        """
-        Realizes the sample given the index of the instances.
-
-        :param index: indexes of the instances to select
-        :return: an instance of :class:`qp.data.LabelledCollection`
-        """
-        return self.data.sampling_from_index(index)
-
     def total(self):
         """
         Returns the number of samples that will be generated

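With sample() hoisted into OnLabelledCollectionProtocol, the per-protocol copies become redundant and subclasses resolve the method through inheritance; a quick check (a sketch, assuming the refactored module):

from quapy.protocol import APP, OnLabelledCollectionProtocol

assert 'sample' not in APP.__dict__                       # no longer defined locally
assert APP.sample is OnLabelledCollectionProtocol.sample  # inherited from the base class
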
@@ -329,7 +329,7 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
         return F.num_prevalence_combinations(self.n_prevalences, self.data.n_classes, self.repeats)


-class NPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
+class NPP(OnLabelledCollectionProtocol):
     """
     A generator of samples that implements the natural prevalence protocol (NPP). The NPP consists of drawing
     samples uniformly at random, therefore approximately preserving the natural prevalence of the collection.

@@ -365,15 +365,6 @@ class NPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
         indexes.append(index)
         return indexes

-    def sample(self, index):
-        """
-        Realizes the sample given the index of the instances.
-
-        :param index: indexes of the instances to select
-        :return: an instance of :class:`qp.data.LabelledCollection`
-        """
-        return self.data.sampling_from_index(index)
-
     def total(self):
         """
         Returns the number of samples that will be generated (equals to "repeats")

@@ -383,7 +374,7 @@ class NPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
         return self.repeats


-class UPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
+class UPP(OnLabelledCollectionProtocol):
     """
     A variant of :class:`APP` that, instead of using a grid of equidistant prevalence values,
     relies on the Kraemer algorithm for sampling unit (k-1)-simplex uniformly at random, with

@@ -423,14 +414,53 @@ class UPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
         indexes.append(index)
         return indexes

-    def sample(self, index):
+    def total(self):
         """
-        Realizes the sample given the index of the instances.
+        Returns the number of samples that will be generated (equals to "repeats")

-        :param index: indexes of the instances to select
-        :return: an instance of :class:`qp.data.LabelledCollection`
+        :return: int
         """
-        return self.data.sampling_from_index(index)
+        return self.repeats
+
+
+class DirichletProtocol(OnLabelledCollectionProtocol):
+    """
+    A protocol that establishes a prior Dirichlet distribution for the prevalence of the samples.
+    Note that providing an all-ones vector of Dirichlet parameters is equivalent to invoking the
+    APP protocol (although each protocol will generate a different series of samples given a
+    fixed seed, since the implementation is different).
+
+    :param data: a `LabelledCollection` from which the samples will be drawn
+    :param alpha: an array-like of shape (n_classes,) with the parameters of the Dirichlet distribution
+    :param sample_size: integer, the number of instances in each sample; if None (default) then it is taken from
+        qp.environ["SAMPLE_SIZE"]. If this is not set, a ValueError exception is raised.
+    :param repeats: the number of samples to generate. Default is 100.
+    :param random_state: allows replicating samples across runs (default 0, meaning that the sequence of samples
+        will be the same every time the protocol is called)
+    :param return_type: set to "sample_prev" (default) to get the pairs of (sample, prevalence) at each iteration, or
+        to "labelled_collection" to get instead instances of LabelledCollection
+    """
+
+    def __init__(self, data: LabelledCollection, alpha, sample_size=None, repeats=100, random_state=0,
+                 return_type='sample_prev'):
+        assert len(alpha)>1, 'wrong parameters: alpha must be an array-like of shape (n_classes,)'
+        super(DirichletProtocol, self).__init__(random_state)
+        self.data = data
+        self.alpha = alpha
+        self.sample_size = qp._get_sample_size(sample_size)
+        self.repeats = repeats
+        self.random_state = random_state
+        self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
+
+    def samples_parameters(self):
+        """
+        Return all the necessary parameters to replicate the samples.
+
+        :return: a list of indexes that realize the sampling
+        """
+        prevs = np.random.dirichlet(self.alpha, size=self.repeats)
+        indexes = [self.data.sampling_index(self.sample_size, *prevs_i) for prevs_i in prevs]
+        return indexes
+
+    def total(self):
+        """

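A hedged usage sketch of the new DirichletProtocol (the dataset and alpha are illustrative; an all-ones alpha gives a uniform prior over the simplex):

import quapy as qp
from quapy.protocol import DirichletProtocol

data = qp.datasets.fetch_UCIMulticlassLabelledCollection('academic-success')  # illustrative
alpha = [1.0] * data.n_classes
prot = DirichletProtocol(data, alpha=alpha, sample_size=200, repeats=100, random_state=0)
for sample, prev in prot():
    ...  # prev is drawn from Dirichlet(alpha); sample realizes that prevalence
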
@@ -450,7 +480,7 @@ class DomainMixer(AbstractStochasticSeededProtocol):
     :param sample_size: integer, the number of instances in each sample; if None (default) then it is taken from
         qp.environ["SAMPLE_SIZE"]. If this is not set, a ValueError exception is raised.
     :param repeats: int, number of samples to draw for every mixture rate
-    :param prevalence: the prevalence to preserv along the mixtures. If specified, should be an array containing
+    :param prevalence: the prevalence to preserve along the mixtures. If specified, should be an array containing
         one prevalence value (positive float) for each class and summing up to one. If not specified, the prevalence
         will be taken from the domain A (default).
     :param mixture_points: an integer indicating the number of points to take from a linear scale (e.g., 21 will

@@ -2,6 +2,7 @@ import unittest
 import numpy as np

 import quapy.functional
+from protocol import DirichletProtocol
 from quapy.data import LabelledCollection
 from quapy.protocol import APP, NPP, UPP, DomainMixer, AbstractStochasticSeededProtocol

@@ -138,6 +139,31 @@ class TestProtocols(unittest.TestCase):

         self.assertNotEqual(samples1, samples2)

+    def test_dirichlet_replicate(self):
+        data = mock_labelled_collection()
+        p = DirichletProtocol(data, alpha=[1,2,3,4], sample_size=5, repeats=10, random_state=42)
+
+        samples1 = samples_to_str(p)
+        samples2 = samples_to_str(p)
+
+        self.assertEqual(samples1, samples2)
+
+        p = DirichletProtocol(data, alpha=[1,2,3,4], sample_size=5, repeats=10, random_state=0)
+
+        samples1 = samples_to_str(p)
+        samples2 = samples_to_str(p)
+
+        self.assertEqual(samples1, samples2)
+
+    def test_dirichlet_not_replicate(self):
+        data = mock_labelled_collection()
+        p = DirichletProtocol(data, alpha=[1,2,3,4], sample_size=5, repeats=10, random_state=None)
+
+        samples1 = samples_to_str(p)
+        samples2 = samples_to_str(p)
+
+        self.assertNotEqual(samples1, samples2)
+
     def test_covariate_shift_replicate(self):
         dataA = mock_labelled_collection('domA')
         dataB = mock_labelled_collection('domB')
|
|||
Loading…
Reference in New Issue