From d34b086a767147c81e7db9312fd206c7936a9cd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Fri, 15 Mar 2024 17:58:23 +0100 Subject: [PATCH] Refactor solving routine --- quapy/functional.py | 64 +++++++++++++++++++++++++++++++++++++ quapy/method/aggregative.py | 28 ++++------------ 2 files changed, 71 insertions(+), 21 deletions(-) diff --git a/quapy/functional.py b/quapy/functional.py index 1459a0f..eb4485e 100644 --- a/quapy/functional.py +++ b/quapy/functional.py @@ -1,4 +1,5 @@ import itertools +import warnings from collections import defaultdict from typing import Literal, Union, Callable @@ -426,3 +427,66 @@ def clip_prevalence(p: np.ndarray, method: Literal[None, "none", "clip", "projec return _project_onto_probability_simplex(p) else: raise ValueError(f"Method {method} not known.") + + +def solve_adjustment( + p_c_y: np.ndarray, + p_c: np.ndarray, + method: Literal["inversion", "invariant-ratio"], + solver: Literal["exact", "minimize", "exact-raise", "exact-cc"], +) -> np.ndarray: + """ + Function finding the prevalence vector by adjusting + the classifier predictions. + + :param p_c_y: array of shape `(n_classes, n_classes,)` with entry `(c,y)` being the estimate + of :math:`P(C=c|Y=y)`, that is, the probability that an instance that belongs to class :math:`y` + ends up being classified as belonging to class :math:`c` + :param p_c: classifier predictions, where the entry `c` is the estimate of :math:`P(C=c)`. Shape `(n_classes,)` + :param method: adjustment method to be used: + 'inversion': matrix inversion method based on the matrix equality :math:`P(C)=P(C|Y)P(Y)`, + which tries to invert `P(C|Y)` matrix. + 'invariant-ratio': invariant ratio estimator of `Vaz et al. `_, + which replaces the last equation with the normalization condition. + :param solver: the method to use for solving the system of linear equations. Valid options are: + 'exact-raise': tries to solve the system using matrix inversion. Raises an error if the matrix has + rank strictly less than `n_classes`. + 'exact-cc': if the matrix is not of full rank, returns `p_c` as the estimates, which corresponds + to no adjustment (i.e., the classify and count method. See :class:`quapy.method.aggregative.CC`) + 'exact': deprecated, defaults to 'exact-cc' + 'minimize': minimizes a loss, so the solution always exists + """ + if solver == "exact": + warnings.warn("The 'exact' solver is deprecated. Use 'exact-raise' or 'exact-cc'", DeprecationWarning, stacklevel=2) + solver = "exact-cc" + + A = np.array(p_c_y, dtype=float) + B = np.array(p_c, dtype=float) + + if method == "inversion": + pass # We leave A and B unchanged + elif method == "invariant-ratio": + # Change the last set of equations + raise NotImplementedError + else: + raise ValueError(f"Flavour {method} not known.") + + + if solver == "minimize": + def loss(prev): + return np.linalg.norm(A @ prev - B) + return optim_minimize(loss, n_classes=A.shape[0]) + else: + # Solvers based on matrix inversion, so we use try/except block + try: + return np.linalg.solve(A, B) + except np.linalg.LinAlgError: + # The matrix is not invertible. + # Depending on the solver, we either raise an error + # or return the classifier predictions without adjustment + if solver == "exact-raise": + raise + elif solver == "exact-cc": + return p_c + else: + raise ValueError(f"Solver {solver} not known.") diff --git a/quapy/method/aggregative.py b/quapy/method/aggregative.py index 5ea0473..77a4eaf 100644 --- a/quapy/method/aggregative.py +++ b/quapy/method/aggregative.py @@ -435,27 +435,13 @@ class ACC(AggregativeCrispQuantifier): :return: an adjusted `np.ndarray` of shape `(n_classes,)` with the corrected class prevalence estimates """ - A = PteCondEstim - B = prevs_estim - - if solver == 'exact': - # attempts an exact solution of the linear system (may fail) - - try: - adjusted_prevs = np.linalg.solve(A, B) - adjusted_prevs = F.clip_prevalence(adjusted_prevs, method="clip") - except np.linalg.LinAlgError: - adjusted_prevs = prevs_estim # no way to adjust them! - - return adjusted_prevs - - elif solver == 'minimize': - # poses the problem as an optimization one, and tries to minimize the norm of the differences - - def loss(prev): - return np.linalg.norm(A @ prev - B) - - return F.optim_minimize(loss, n_classes=A.shape[0]) + estimate = F.solve_adjustment( + p_c_y=PteCondEstim, + p_c=prevs_estim, + solver=solver, + method='inversion', + ) + return F.clip_prevalence(estimate, method="clip") class PCC(AggregativeSoftQuantifier):