From 2cc49083262f0a603fc2e6c8415ae2f7a422691b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= <pczyz@protonmail.com>
Date: Fri, 15 Mar 2024 14:01:24 +0100
Subject: [PATCH] Sketch of the Bayesian quantification

---
 quapy/functional.py         | 26 ++++++++---
 quapy/method/_bayesian.py   | 78 ++++++++++++++++++++++++++++++++
 quapy/method/aggregative.py | 89 ++++++++++++++++++++++++++++++++++++-
 setup.py                    |  7 ++-
 4 files changed, 188 insertions(+), 12 deletions(-)
 create mode 100644 quapy/method/_bayesian.py

diff --git a/quapy/functional.py b/quapy/functional.py
index c6dc351..3a4ebfa 100644
--- a/quapy/functional.py
+++ b/quapy/functional.py
@@ -28,22 +28,34 @@ def prevalence_linspace(n_prevalences=21, repeats=1, smooth_limits_epsilon=0.01)
     return p
 
 
-def prevalence_from_labels(labels, classes):
+def counts_from_labels(labels, classes):
     """
-    Computed the prevalence values from a vector of labels.
+    Computes the count values from a vector of labels.
 
-    :param labels: array-like of shape `(n_instances)` with the label for each instance
+    :param labels: array-like of shape `(n_instances,)` with the label for each instance
     :param classes: the class labels. This is needed in order to correctly compute the prevalence vector even when
         some classes have no examples.
-    :return: an ndarray of shape `(len(classes))` with the class prevalence values
+    :return: an ndarray of shape `(len(classes),)` with the occurrence counts of each class
     """
     if labels.ndim != 1:
         raise ValueError(f'param labels does not seem to be a ndarray of label predictions')
     unique, counts = np.unique(labels, return_counts=True)
     by_class = defaultdict(lambda:0, dict(zip(unique, counts)))
-    prevalences = np.asarray([by_class[class_] for class_ in classes], dtype=float)
-    prevalences /= prevalences.sum()
-    return prevalences
+    counts = np.asarray([by_class[class_] for class_ in classes], dtype=int)
+    return counts
+
+
+def prevalence_from_labels(labels, classes):
+    """
+    Computes the prevalence values from a vector of labels.
+
+    :param labels: array-like of shape `(n_instances,)` with the label for each instance
+    :param classes: the class labels. This is needed in order to correctly compute the prevalence vector even when
+        some classes have no examples.
+    :return: an ndarray of shape `(len(classes))` with the class prevalence values
+    """
+    counts = np.array(counts_from_labels(labels, classes), dtype=float)
+    return counts / np.sum(counts)
 
 
 def prevalence_from_probabilities(posteriors, binarize: bool = False):
diff --git a/quapy/method/_bayesian.py b/quapy/method/_bayesian.py
new file mode 100644
index 0000000..78a2c66
--- /dev/null
+++ b/quapy/method/_bayesian.py
@@ -0,0 +1,78 @@
+"""
+Utility functions for `Bayesian quantification <https://arxiv.org/abs/2302.09159>`_ methods.
+"""
+import numpy as np
+
+try:
+    import jax
+    import jax.numpy as jnp
+    import numpyro
+    import numpyro.distributions as dist
+
+    DEPENDENCIES_INSTALLED = True
+except ImportError:
+    jax = None
+    jnp = None
+    numpyro = None
+    dist = None
+
+    DEPENDENCIES_INSTALLED = False
+
+
+P_TEST_Y: str = "P_test(Y)"
+P_TEST_C: str = "P_test(C)"
+P_C_COND_Y: str = "P(C|Y)"
+
+
+def model(n_c_unlabeled: np.ndarray, n_y_and_c_labeled: np.ndarray) -> None:
+    """
+    Defines a probabilistic model in `NumPyro <https://num.pyro.ai/>`_.
+
+    :param n_c_unlabeled: a `np.ndarray` of shape `(n_predicted_classes,)`
+        with entry `c` being the number of instances predicted as class `c`.
+    :param n_y_and_c_labeled: a `np.ndarray` of shape `(n_classes, n_predicted_classes)`
+        with entry `(y, c)` being the number of instances labeled as class `y` and predicted as class `c`.
+    """
+    n_y_labeled = n_y_and_c_labeled.sum(axis=1)
+
+    K = len(n_c_unlabeled)
+    L = len(n_y_labeled)
+
+    pi_ = numpyro.sample(P_TEST_Y, dist.Dirichlet(jnp.ones(L)))
+    p_c_cond_y = numpyro.sample(P_C_COND_Y, dist.Dirichlet(jnp.ones(K).repeat(L).reshape(L, K)))
+
+    with numpyro.plate('plate', L):
+        numpyro.sample('F_yc', dist.Multinomial(n_y_labeled, p_c_cond_y), obs=n_y_and_c_labeled)
+
+    p_c = numpyro.deterministic(P_TEST_C, jnp.einsum("yc,y->c", p_c_cond_y, pi_))
+    numpyro.sample('N_c', dist.Multinomial(jnp.sum(n_c_unlabeled), p_c), obs=n_c_unlabeled)
+
+
+def sample_posterior(
+    n_c_unlabeled: np.ndarray,
+    n_y_and_c_labeled: np.ndarray,
+    num_warmup: int,
+    num_samples: int,
+    seed: int = 0,
+) -> dict:
+    """
+    Samples from the Bayesian quantification model in NumPyro using the
+    `NUTS <https://arxiv.org/abs/1111.4246>`_ sampler.
+
+    :param n_c_unlabeled: a `np.ndarray` of shape `(n_predicted_classes,)`
+        with entry `c` being the number of instances predicted as class `c`.
+    :param n_y_and_c_labeled: a `np.ndarray` of shape `(n_classes, n_predicted_classes)`
+        with entry `(y, c)` being the number of instances labeled as class `y` and predicted as class `c`.
+    :param num_warmup: the number of warmup steps.
+    :param num_samples: the number of samples to draw.
+    :seed: the random seed.
+    :return: a `dict` with the samples. The keys are the names of the latent variables.
+    """
+    mcmc = numpyro.infer.MCMC(
+        numpyro.infer.NUTS(model),
+        num_warmup=num_warmup,
+        num_samples=num_samples,
+    )
+    rng_key = jax.random.PRNGKey(seed)
+    mcmc.run(rng_key, n_c_unlabeled=n_c_unlabeled, n_y_and_c_labeled=n_y_and_c_labeled)
+    return mcmc.get_samples()
diff --git a/quapy/method/aggregative.py b/quapy/method/aggregative.py
index feeb5f2..a608053 100644
--- a/quapy/method/aggregative.py
+++ b/quapy/method/aggregative.py
@@ -11,6 +11,7 @@ from sklearn.model_selection import cross_val_predict
 
 import quapy as qp
 import quapy.functional as F
+import quapy._bayesian as _bayesian
 from quapy.functional import get_divergence
 from quapy.classification.calibration import NBVSCalibration, BCTSCalibration, TSCalibration, VSCalibration
 from quapy.classification.svmperf import SVMperf
@@ -384,7 +385,8 @@ class ACC(AggregativeCrispQuantifier):
         self.solver = solver
 
     def _check_init_parameters(self):
-        assert self.solver in ['exact', 'minimize'], "unknown solver; valid ones are 'exact', 'minimize'"
+        if self.solver not in ['exact', 'minimize']:
+            raise ValueError("unknown solver; valid ones are 'exact', 'minimize'")
 
     def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
         """
@@ -453,6 +455,91 @@ class ACC(AggregativeCrispQuantifier):
             return F.optim_minimize(loss, n_classes=A.shape[0])
 
 
+class BayesianCC(AggregativeCrispQuantifier):
+    """
+    `Bayesian quantification <https://arxiv.org/abs/2302.09159>`_ methods,
+    which is a variant of :class`ACC` that calculates the posterior probability distribution
+    over the prevalence vectors, rather than providing a point estimate obtained
+    by matrix inversion.
+
+    Can be used to diagnose degeneracy in the predictions visible when the confusion
+    matrix has high condition number or to quantify uncertainty around the point estimate.
+
+    This method relies on extra dependencies, which have to be installed via:
+    `$ pip install quapy[bayes]`
+
+    :param classifier: a sklearn's Estimator that generates a classifier
+    :param val_split: specifies the data used for generating classifier predictions. This specification
+        should be a float in (0, 1) indicating the proportion of stratified held-out validation set to
+        be extracted from the training set
+    :num_warmup: number of warmup iterations for the MCMC sampler
+    :num_samples: number of samples to draw from the posterior
+    :mcmc_seed: random seed for the MCMC sampler
+    """
+    def __init__(self, classifier: BaseEstimator, val_split: float = 0.75, num_warmup: int = 500, num_samples: int = 1_000, mcmc_seed: int = 0) -> None:
+        if num_warmup <= 0:
+            raise ValueError(f'num_warmup must be a positive integer, got {num_warmup}')
+        if num_samples <= 0:
+            raise ValueError(f'num_samples must be a positive integer, got {num_samples}')
+
+        if (not isinstance(val_split, float)) or val_split <= 0 or val_split >= 1:
+            raise ValueError(f'val_split must be a float in (0, 1), got {val_split}')
+
+        if _bayesian.DEPENDENCIES_INSTALLED is False:
+            raise ImportError("Auxiliary dependencies are required. Run `$ pip install quapy[bayes]` to install them.")
+
+        self.classifier = classifier
+        self.val_split = val_split
+        self.num_warmup = num_warmup
+        self.num_samples = num_samples
+        self.mcmc_seed = mcmc_seed
+
+        # Array of shape (n_classes, n_predicted_classes) where entry (y, c) is the number of instances labeled as class y and predicted as class c
+        # By default it's None and it's set during the `aggregation_fit` phase
+        self._n_and_c_labeled = None
+
+        # Dictionary with posterior samples, set when `aggregate` is provided.
+        self._samples = None
+
+    def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
+        """
+        Estimates the misclassification rates.
+
+        :param classif_predictions: classifier predictions with true labels
+        """
+        pred_labels, true_labels = classif_predictions.Xy
+        self._n_and_c_labeled = confusion_matrix(y_true=true_labels, y_pred=pred_labels, labels=self.classifier.classes_)
+
+    def sample_from_posterior(self, classif_predictions):
+        if self._n_and_c_labeled is None:
+            raise ValueError("aggregation_fit must be called before sample_from_posterior")
+
+        n_c_unlabeled = F.counts_from_labels(classif_predictions, self.classifier.classes_)
+
+        self._samples = _bayesian.sample_posterior(
+            n_c_unlabeled=n_c_unlabeled,
+            n_y_and_c_labeled=self._n_and_c_labeled,
+            num_warmup=self.num_warmup,
+            num_samples=self.num_samples,
+            seed=self.mcmc_seed,
+        )
+        return self._samples
+
+    def get_prevalence_samples(self):
+        if self._samples is None:
+            raise ValueError("sample_from_posterior must be called before get_prevalence_samples")
+        return self._samples[_bayesian.P_TEST_Y]
+
+    def get_conditional_probability_samples(self):
+        if self._samples is None:
+            raise ValueError("sample_from_posterior must be called before get_conditional_probability_samples")
+        return self._samples[_bayesian.P_C_COND_Y]
+
+    def aggregate(self, classif_predictions):
+        samples = self.sample_from_posterior(classif_predictions)[_bayesian.P_TEST_Y]
+        return np.asarray(samples.mean(axis=0), dtype=float)
+
+
 class PCC(AggregativeSoftQuantifier):
     """
     `Probabilistic Classify & Count <https://ieeexplore.ieee.org/abstract/document/5694031>`_,
diff --git a/setup.py b/setup.py
index 9ccb348..1f6c6fb 100644
--- a/setup.py
+++ b/setup.py
@@ -123,10 +123,9 @@ setup(
     #
     # Similar to `install_requires` above, these must be valid existing
     # projects.
-    # extras_require={  # Optional
-    #     'dev': ['check-manifest'],
-    #     'test': ['coverage'],
-    # },
+    extras_require={  # Optional
+       'bayes': ['jax', 'jaxlib', 'numpyro'],
+    },
 
     # If there are data files included in your packages that need to be
     # installed, specify them here.