From e33b29135713688615c26bfc163ce8d8512046b2 Mon Sep 17 00:00:00 2001 From: Alejandro Moreo Date: Wed, 7 Jan 2026 18:21:46 +0100 Subject: [PATCH] adding nuts to kdey --- BayesianKDEy/_bayeisan_kdey.py | 135 +++++++++++++++++++++++- BayesianKDEy/full_experiments.py | 12 ++- BayesianKDEy/generate_results.py | 30 ++++-- BayesianKDEy/single_experiment_debug.py | 4 +- CHANGE_LOG.txt | 3 +- quapy/data/datasets.py | 2 +- quapy/method/_bayesian.py | 5 +- quapy/protocol.py | 86 ++++++++++----- quapy/tests/test_protocols.py | 26 +++++ 9 files changed, 251 insertions(+), 52 deletions(-) diff --git a/BayesianKDEy/_bayeisan_kdey.py b/BayesianKDEy/_bayeisan_kdey.py index e20db79..69b6523 100644 --- a/BayesianKDEy/_bayeisan_kdey.py +++ b/BayesianKDEy/_bayeisan_kdey.py @@ -1,3 +1,5 @@ +from functools import lru_cache + from numpy.ma.core import shape from sklearn.base import BaseEstimator import numpy as np @@ -5,13 +7,19 @@ import numpy as np import quapy.util from quapy.method._kdey import KDEBase from quapy.method.confidence import WithConfidenceABC, ConfidenceRegionABC -from quapy.functional import CLRtransformation, ILRtransformation +from quapy.functional import CLRtransformation from quapy.method.aggregative import AggregativeSoftQuantifier from tqdm import tqdm import quapy.functional as F -#import emcee +import emcee +import jax +import jax.numpy as jnp +import numpyro +import numpyro.distributions as dist +from numpyro.infer import MCMC, NUTS + class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC): """ @@ -49,6 +57,7 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC): explore='simplex', step_size=0.05, temperature=1., + engine='numpyro', verbose: bool = False): if num_warmup <= 0: @@ -58,6 +67,7 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC): assert explore in ['simplex', 'clr', 'ilr'], \ f'unexpected value for param {explore=}; valid ones are "simplex", "clr", and "ilr"' assert temperature>0., f'temperature must be >0' + assert engine in ['rw-mh', 'emcee', 'numpyro'] super().__init__(classifier, fit_classifier, val_split) self.bandwidth = KDEBase._check_bandwidth(bandwidth, kernel) @@ -70,6 +80,7 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC): self.explore = explore self.step_size = step_size self.temperature = temperature + self.engine = engine self.verbose = verbose def aggregation_fit(self, classif_predictions, labels): @@ -77,8 +88,12 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC): return self def aggregate(self, classif_predictions): - # self.prevalence_samples = self._bayesian_kde(classif_predictions, init=None, verbose=self.verbose) - self.prevalence_samples = self._bayesian_emcee(classif_predictions) + if self.engine == 'rw-mh': + self.prevalence_samples = self._bayesian_kde(classif_predictions, init=None, verbose=self.verbose) + elif self.engine == 'emcee': + self.prevalence_samples = self._bayesian_emcee(classif_predictions) + elif self.engine == 'numpyro': + self.prevalence_samples = self._bayesian_numpyro(classif_predictions) return self.prevalence_samples.mean(axis=0) def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC): @@ -187,6 +202,8 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC): def _bayesian_emcee(self, X_probs): ndim = X_probs.shape[1] nwalkers = 32 + if nwalkers < (ndim*2): + nwalkers = ndim * 2 + 1 f = CLRtransformation() @@ -198,6 +215,7 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC): kdes = self.mix_densities test_densities = np.asarray([self.pdf(kde_i, X_probs, self.kernel) for kde_i in kdes]) + # test_densities_unconstrained = [f(t) for t in test_densities] # p0 = np.random.normal(nwalkers, ndim) p0 = F.uniform_prevalence_sampling(ndim, nwalkers) @@ -211,5 +229,114 @@ class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC): samples = f.inverse(samples) return samples + def _bayesian_numpyro(self, X_probs): + kdes = self.mix_densities + test_densities = np.asarray( + [self.pdf(kde_i, X_probs, self.kernel) for kde_i in kdes] + ) + + # move to jax + test_densities = jnp.array(test_densities) + + kernel = NUTS(self._numpyro_model) + mcmc = MCMC( + kernel, + num_warmup=self.num_warmup, + num_samples=self.num_samples, + num_chains=1, + progress_bar=self.verbose, + ) + + rng_key = jax.random.PRNGKey(self.mcmc_seed) + mcmc.run(rng_key, test_densities) + + samples_z = mcmc.get_samples()["z"] + + # back to simplex + ilr = ILRtransformation(jax_mode=True) + samples_prev = np.asarray(ilr.inverse(np.asarray(samples_z))) + + return samples_prev + + def _numpyro_model(self, test_densities): + """ + test_densities: shape (C, N) + """ + C = test_densities.shape[0] + ilr = ILRtransformation(jax_mode=True) + + # sample in unconstrained R^{C-1} + z = numpyro.sample( + "z", + dist.Normal(0.0, 1.0).expand([C - 1]) + ) + + prev = ilr.inverse(z) # simplex, shape (C,) + + # likelihood + likelihoods = jnp.dot(prev, test_densities) + numpyro.factor( + "loglik", + (1.0 / self.temperature) * jnp.sum(jnp.log(likelihoods + 1e-10)) + ) + + def in_simplex(x): return np.all(x >= 0) and np.isclose(x.sum(), 1) + + + + +class ILRtransformation(F.CompositionalTransformation): + def __init__(self, jax_mode=False): + self.jax_mode = jax_mode + + def array(self, X): + if self.jax_mode: + return jnp.array(X) + else: + return np.asarray(X) + + def __call__(self, X): + X = self.array(X) + X = quapy.error.smooth(X, self.EPSILON) + k = X.shape[-1] + V = self.array(self.get_V(k)) + logp = jnp.log(X) if self.jax_mode else np.log(X) + return logp @ V.T + + def inverse(self, Z): + Z = self.array(Z) + k_minus_1 = Z.shape[-1] + k = k_minus_1 + 1 + V = self.array(self.get_V(k)) + logp = Z @ V + p = jnp.exp(logp) if self.jax_mode else np.exp(logp) + p = p / jnp.sum(p, axis=-1, keepdims=True) if self.jax_mode else p / np.sum(p, axis=-1, keepdims=True) + return p + + @lru_cache(maxsize=None) + def get_V(self, k): + def helmert_matrix(k): + """ + Returns the (k x k) Helmert matrix. + """ + H = np.zeros((k, k)) + for i in range(1, k): + H[i, :i] = 1 + H[i, i] = -(i) + H[i] = H[i] / np.sqrt(i * (i + 1)) + # row 0 stays zeros; will be discarded + return H + + def ilr_basis(k): + """ + Constructs an orthonormal ILR basis using the Helmert submatrix. + Output shape: (k-1, k) + """ + H = helmert_matrix(k) + V = H[1:, :] # remove first row of zeros + return V + + return ilr_basis(k) + diff --git a/BayesianKDEy/full_experiments.py b/BayesianKDEy/full_experiments.py index 0a1dad2..4b0afa2 100644 --- a/BayesianKDEy/full_experiments.py +++ b/BayesianKDEy/full_experiments.py @@ -64,9 +64,9 @@ def methods(): only_multiclass = 'only_multiclass' # yield 'BootstrapACC', ACC(LR()), acc_hyper, lambda hyper: AggregativeBootstrap(ACC(LR()), n_test_samples=1000, random_state=0), multiclass_method - # yield 'BayesianACC', ACC(LR()), acc_hyper, lambda hyper: BayesianCC(LR(), mcmc_seed=0), multiclass_method + yield 'BayesianACC', ACC(LR()), acc_hyper, lambda hyper: BayesianCC(LR(), mcmc_seed=0), multiclass_method - yield 'BootstrapEMQ', EMQ(LR(), on_calib_error='backup', val_split=5), emq_hyper, lambda hyper: AggregativeBootstrap(EMQ(LR(), on_calib_error='backup', calib=hyper['calib'], val_split=5), n_test_samples=1000, random_state=0), multiclass_method + # yield 'BootstrapEMQ', EMQ(LR(), on_calib_error='backup', val_split=5), emq_hyper, lambda hyper: AggregativeBootstrap(EMQ(LR(), on_calib_error='backup', calib=hyper['calib'], val_split=5), n_test_samples=1000, random_state=0), multiclass_method # yield 'BootstrapHDy', DMy(LR()), hdy_hyper, lambda hyper: AggregativeBootstrap(DMy(LR(), **hyper), n_test_samples=1000, random_state=0), multiclass_method # yield 'BayesianHDy', DMy(LR()), hdy_hyper, lambda hyper: PQ(LR(), stan_seed=0, **hyper), only_binary @@ -78,6 +78,8 @@ def methods(): # yield 'BayKDEy*CLR2', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, explore='clr', step_size=.05, **hyper), multiclass_method # yield 'BayKDEy*ILR', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, explore='ilr', step_size=.15, **hyper), only_multiclass # yield 'BayKDEy*ILR2', KDEyILR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='ilr', mcmc_seed=0, explore='ilr', step_size=.1, **hyper), only_multiclass + yield f'BaKDE-emcee', KDEyML(LR()), kdey_hyper, lambda hyper: BayesianKDEy(mcmc_seed=0, num_warmup=100, num_samples=100, step_size=.1, engine='emcee', **hyper), multiclass_method + yield f'BaKDE-numpyro', KDEyML(LR()), kdey_hyper, lambda hyper: BayesianKDEy(mcmc_seed=0, step_size=.1, engine='numpyro', **hyper), multiclass_method def model_selection(train: LabelledCollection, point_quantifier: AggregativeQuantifier, grid: dict): @@ -153,18 +155,18 @@ if __name__ == '__main__': binary = { 'datasets': qp.datasets.UCI_BINARY_DATASETS, 'fetch_fn': qp.datasets.fetch_UCIBinaryDataset, - 'sample_size': 500 + 'sample_size': 100 # previous: 500 } multiclass = { 'datasets': qp.datasets.UCI_MULTICLASS_DATASETS, 'fetch_fn': qp.datasets.fetch_UCIMulticlassDataset, - 'sample_size': 1000 + 'sample_size': 200 # previous: 1000 } result_dir = Path('./results') - for setup in [binary, multiclass]: # [binary, multiclass]: + for setup in [multiclass]: # [binary, multiclass]: qp.environ['SAMPLE_SIZE'] = setup['sample_size'] for data_name in setup['datasets']: print(f'dataset={data_name}') diff --git a/BayesianKDEy/generate_results.py b/BayesianKDEy/generate_results.py index 04dedd9..5604af4 100644 --- a/BayesianKDEy/generate_results.py +++ b/BayesianKDEy/generate_results.py @@ -66,16 +66,21 @@ def update_pickle_with_region(report, file, conf_name, conf_region_class, **kwar if f'coverage-{conf_name}' not in report: covs, amps, winkler = compute_coverage_amplitude(conf_region_class, **kwargs) + # amperr (lower is better) counts the amplitude when the true vale was covered, or 1 (max amplitude) otherwise + amperrs = [amp if cov == 1.0 else 1. for amp, cov in zip(amps, covs)] + update_fields = { f'coverage-{conf_name}': covs, f'amplitude-{conf_name}': amps, - f'winkler-{conf_name}': winkler + f'winkler-{conf_name}': winkler, + f'amperr-{conf_name}': amperrs, } update_pickle(report, file, update_fields) -methods = None # show all methods -# methods = ['BayesianACC', 'BayesianKDEy'] + +# methods = None # show all methods +methods = ['BayesianACC', 'BayesianKDEy', 'BaKDE-emcee', 'BaKDE-numpyro'] for setup in ['multiclass']: path = f'./results/{setup}/*.pkl' @@ -103,22 +108,26 @@ for setup in ['multiclass']: table['c-CI'].extend(report['coverage-CI']) table['a-CI'].extend(report['amplitude-CI']) table['w-CI'].extend(report['winkler-CI']) + table['amperr-CI'].extend(report['amperr-CI']) table['c-CE'].extend(report['coverage-CE']) table['a-CE'].extend(report['amplitude-CE']) + table['amperr-CE'].extend(report['amperr-CE']) table['c-CLR'].extend(report['coverage-CLR']) table['a-CLR'].extend(report['amplitude-CLR']) + table['amperr-CLR'].extend(report['amperr-CLR']) table['c-ILR'].extend(report['coverage-ILR']) table['a-ILR'].extend(report['amplitude-ILR']) + table['amperr-ILR'].extend(report['amperr-ILR']) table['aitch'].extend(qp.error.dist_aitchison(results['true-prevs'], results['point-estim'])) # table['aitch-well'].extend(qp.error.dist_aitchison(results['true-prevs'], [ConfidenceEllipseILR(samples).mean_ for samples in results['samples']])) # table['aitch'].extend() - table['reg-score-ILR'].extend( - [region_score(true_prev, ConfidenceEllipseILR(samples)) for true_prev, samples in zip(results['true-prevs'], results['samples'])] - ) + # table['reg-score-ILR'].extend( + # [region_score(true_prev, ConfidenceEllipseILR(samples)) for true_prev, samples in zip(results['true-prevs'], results['samples'])] + # ) @@ -145,7 +154,7 @@ for setup in ['multiclass']: # if n < min_train: # df = df[df["dataset"] != data_name] - for region in ['ILR']: # , 'CI', 'CE', 'CLR', 'ILR']: + for region in ['CI', 'CE', 'CLR', 'ILR']: if setup == 'binary' and region=='ILR': continue # pv = pd.pivot_table( @@ -153,12 +162,15 @@ for setup in ['multiclass']: # ) pv = pd.pivot_table( df, index='dataset', columns='method', values=[ + f'amperr-{region}', + f'a-{region}', + f'c-{region}', #f'w-{region}', - # 'ae', + 'ae', # 'rae', # f'aitch', # f'aitch-well' - 'reg-score-ILR', + # 'reg-score-ILR', ], margins=True ) pv['n_classes'] = pv.index.map(n_classes).astype('Int64') diff --git a/BayesianKDEy/single_experiment_debug.py b/BayesianKDEy/single_experiment_debug.py index a33df67..7056fc7 100644 --- a/BayesianKDEy/single_experiment_debug.py +++ b/BayesianKDEy/single_experiment_debug.py @@ -68,8 +68,8 @@ if __name__ == '__main__': setup = multiclass # data_name = 'isolet' - # data_name = 'cmc' - data_name = 'abalone' + data_name = 'academic-success' + # data_name = 'abalone' qp.environ['SAMPLE_SIZE'] = setup['sample_size'] print(f'dataset={data_name}') diff --git a/CHANGE_LOG.txt b/CHANGE_LOG.txt index 9761c29..02b4166 100644 --- a/CHANGE_LOG.txt +++ b/CHANGE_LOG.txt @@ -1,13 +1,14 @@ Change Log 0.2.1 ----------------- +- Added DirichletProtocol, which allows to generate samples according to a parameterized Dirichlet prior. - Added squared ratio error. - Improved efficiency of confidence regions coverage functions - Added Precise Quantifier to WithConfidence methods (a Bayesian adaptation of HDy) - Improved documentation of confidence regions. - Added ReadMe method by Daniel Hopkins and Gary King - Internal index in LabelledCollection is now "lazy", and is only constructed if required. -- I have added dist_aitchison and mean_dist_aitchison as a new error evaluation metric +- Added dist_aitchison and mean_dist_aitchison as a new error evaluation metric. Change Log 0.2.0 ----------------- diff --git a/quapy/data/datasets.py b/quapy/data/datasets.py index c08748f..801b968 100644 --- a/quapy/data/datasets.py +++ b/quapy/data/datasets.py @@ -758,7 +758,7 @@ def fetch_UCIMulticlassLabelledCollection(dataset_name, data_home=None, min_clas # restrict classes to only those with at least min_ipc instances classes = classes[data.counts() >= min_ipc] # filter X and y keeping only datapoints belonging to valid classes - filter_idx = np.in1d(data.y, classes) + filter_idx = np.isin(data.y, classes) X, y = data.X[filter_idx], data.y[filter_idx] # map classes to range(len(classes)) y = np.searchsorted(classes, y) diff --git a/quapy/method/_bayesian.py b/quapy/method/_bayesian.py index 23507f7..6e75f65 100644 --- a/quapy/method/_bayesian.py +++ b/quapy/method/_bayesian.py @@ -33,7 +33,7 @@ P_TEST_C: str = "P_test(C)" P_C_COND_Y: str = "P(C|Y)" -def model(n_c_unlabeled: np.ndarray, n_y_and_c_labeled: np.ndarray) -> None: +def model_bayesianCC(n_c_unlabeled: np.ndarray, n_y_and_c_labeled: np.ndarray) -> None: """ Defines a probabilistic model in `NumPyro `_. @@ -57,6 +57,7 @@ def model(n_c_unlabeled: np.ndarray, n_y_and_c_labeled: np.ndarray) -> None: numpyro.sample('N_c', dist.Multinomial(jnp.sum(n_c_unlabeled), p_c), obs=n_c_unlabeled) + def sample_posterior( n_c_unlabeled: np.ndarray, n_y_and_c_labeled: np.ndarray, @@ -78,7 +79,7 @@ def sample_posterior( :return: a `dict` with the samples. The keys are the names of the latent variables. """ mcmc = numpyro.infer.MCMC( - numpyro.infer.NUTS(model), + numpyro.infer.NUTS(model_bayesianCC), num_warmup=num_warmup, num_samples=num_samples, progress_bar=False diff --git a/quapy/protocol.py b/quapy/protocol.py index 9a7e5c4..f43522f 100644 --- a/quapy/protocol.py +++ b/quapy/protocol.py @@ -171,7 +171,7 @@ class AbstractStochasticSeededProtocol(AbstractProtocol): return sample -class OnLabelledCollectionProtocol: +class OnLabelledCollectionProtocol(AbstractStochasticSeededProtocol): """ Protocols that generate samples from a :class:`qp.data.LabelledCollection` object. """ @@ -229,8 +229,17 @@ class OnLabelledCollectionProtocol: elif return_type=='index': return lambda lc,params:params + def sample(self, index): + """ + Realizes the sample given the index of the instances. -class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol): + :param index: indexes of the instances to select + :return: an instance of :class:`qp.data.LabelledCollection` + """ + return self.data.sampling_from_index(index) + + +class APP(OnLabelledCollectionProtocol): """ Implementation of the artificial prevalence protocol (APP). The APP consists of exploring a grid of prevalence values containing `n_prevalences` points (e.g., @@ -311,15 +320,6 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol): indexes.append(index) return indexes - def sample(self, index): - """ - Realizes the sample given the index of the instances. - - :param index: indexes of the instances to select - :return: an instance of :class:`qp.data.LabelledCollection` - """ - return self.data.sampling_from_index(index) - def total(self): """ Returns the number of samples that will be generated @@ -329,7 +329,7 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol): return F.num_prevalence_combinations(self.n_prevalences, self.data.n_classes, self.repeats) -class NPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol): +class NPP(OnLabelledCollectionProtocol): """ A generator of samples that implements the natural prevalence protocol (NPP). The NPP consists of drawing samples uniformly at random, therefore approximately preserving the natural prevalence of the collection. @@ -365,15 +365,6 @@ class NPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol): indexes.append(index) return indexes - def sample(self, index): - """ - Realizes the sample given the index of the instances. - - :param index: indexes of the instances to select - :return: an instance of :class:`qp.data.LabelledCollection` - """ - return self.data.sampling_from_index(index) - def total(self): """ Returns the number of samples that will be generated (equals to "repeats") @@ -383,7 +374,7 @@ class NPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol): return self.repeats -class UPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol): +class UPP(OnLabelledCollectionProtocol): """ A variant of :class:`APP` that, instead of using a grid of equidistant prevalence values, relies on the Kraemer algorithm for sampling unit (k-1)-simplex uniformly at random, with @@ -423,14 +414,53 @@ class UPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol): indexes.append(index) return indexes - def sample(self, index): + def total(self): """ - Realizes the sample given the index of the instances. + Returns the number of samples that will be generated (equals to "repeats") - :param index: indexes of the instances to select - :return: an instance of :class:`qp.data.LabelledCollection` + :return: int """ - return self.data.sampling_from_index(index) + return self.repeats + + +class DirichletProtocol(OnLabelledCollectionProtocol): + """ + A protocol that establishes a prior Dirichlet distribution for the prevalence of the samples. + Note that providing an all-ones vector of Dirichlet parameters is equivalent to invoking the + APP protocol (although each protocol will generate a different series of samples given a + fixed seed, since the implementation is different). + + :param data: a `LabelledCollection` from which the samples will be drawn + :param alpha: an array-like of shape (n_classes,) with the parameters of the Dirichlet distribution + :param sample_size: integer, the number of instances in each sample; if None (default) then it is taken from + qp.environ["SAMPLE_SIZE"]. If this is not set, a ValueError exception is raised. + :param repeats: the number of samples to generate. Default is 100. + :param random_state: allows replicating samples across runs (default 0, meaning that the sequence of samples + will be the same every time the protocol is called) + :param return_type: set to "sample_prev" (default) to get the pairs of (sample, prevalence) at each iteration, or + to "labelled_collection" to get instead instances of LabelledCollection + """ + + def __init__(self, data: LabelledCollection, alpha, sample_size=None, repeats=100, random_state=0, + return_type='sample_prev'): + assert len(alpha)>1, 'wrong parameters: alpha must be an array-like of shape (n_classes,)' + super(DirichletProtocol, self).__init__(random_state) + self.data = data + self.alpha = alpha + self.sample_size = qp._get_sample_size(sample_size) + self.repeats = repeats + self.random_state = random_state + self.collator = OnLabelledCollectionProtocol.get_collator(return_type) + + def samples_parameters(self): + """ + Return all the necessary parameters to replicate the samples. + + :return: a list of indexes that realize the sampling + """ + prevs = np.random.dirichlet(self.alpha, size=self.repeats) + indexes = [self.data.sampling_index(self.sample_size, *prevs_i) for prevs_i in prevs] + return indexes def total(self): """ @@ -450,7 +480,7 @@ class DomainMixer(AbstractStochasticSeededProtocol): :param sample_size: integer, the number of instances in each sample; if None (default) then it is taken from qp.environ["SAMPLE_SIZE"]. If this is not set, a ValueError exception is raised. :param repeats: int, number of samples to draw for every mixture rate - :param prevalence: the prevalence to preserv along the mixtures. If specified, should be an array containing + :param prevalence: the prevalence to preserve along the mixtures. If specified, should be an array containing one prevalence value (positive float) for each class and summing up to one. If not specified, the prevalence will be taken from the domain A (default). :param mixture_points: an integer indicating the number of points to take from a linear scale (e.g., 21 will diff --git a/quapy/tests/test_protocols.py b/quapy/tests/test_protocols.py index 4850bd4..e1ee85f 100644 --- a/quapy/tests/test_protocols.py +++ b/quapy/tests/test_protocols.py @@ -2,6 +2,7 @@ import unittest import numpy as np import quapy.functional +from protocol import DirichletProtocol from quapy.data import LabelledCollection from quapy.protocol import APP, NPP, UPP, DomainMixer, AbstractStochasticSeededProtocol @@ -138,6 +139,31 @@ class TestProtocols(unittest.TestCase): self.assertNotEqual(samples1, samples2) + def test_dirichlet_replicate(self): + data = mock_labelled_collection() + p = DirichletProtocol(data, alpha=[1,2,3,4], sample_size=5, repeats=10, random_state=42) + + samples1 = samples_to_str(p) + samples2 = samples_to_str(p) + + self.assertEqual(samples1, samples2) + + p = DirichletProtocol(data, alpha=[1,2,3,4], sample_size=5, repeats=10, random_state=0) + + samples1 = samples_to_str(p) + samples2 = samples_to_str(p) + + self.assertEqual(samples1, samples2) + + def test_dirichlet_not_replicate(self): + data = mock_labelled_collection() + p = DirichletProtocol(data, alpha=[1,2,3,4], sample_size=5, repeats=10, random_state=None) + + samples1 = samples_to_str(p) + samples2 = samples_to_str(p) + + self.assertNotEqual(samples1, samples2) + def test_covariate_shift_replicate(self): dataA = mock_labelled_collection('domA') dataB = mock_labelled_collection('domB')