forked from moreo/QuaPy
Refactor solving routine
This commit is contained in:
parent
4dd66b1921
commit
d34b086a76
|
@ -1,4 +1,5 @@
|
||||||
import itertools
|
import itertools
|
||||||
|
import warnings
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Literal, Union, Callable
|
from typing import Literal, Union, Callable
|
||||||
|
|
||||||
|
@ -426,3 +427,66 @@ def clip_prevalence(p: np.ndarray, method: Literal[None, "none", "clip", "projec
|
||||||
return _project_onto_probability_simplex(p)
|
return _project_onto_probability_simplex(p)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Method {method} not known.")
|
raise ValueError(f"Method {method} not known.")
|
||||||
|
|
||||||
|
|
||||||
|
def solve_adjustment(
    p_c_y: np.ndarray,
    p_c: np.ndarray,
    method: Literal["inversion", "invariant-ratio"],
    solver: Literal["exact", "minimize", "exact-raise", "exact-cc"],
) -> np.ndarray:
    """
    Function finding the prevalence vector by adjusting
    the classifier predictions.

    :param p_c_y: array of shape `(n_classes, n_classes,)` with entry `(c,y)` being the estimate
        of :math:`P(C=c|Y=y)`, that is, the probability that an instance that belongs to class :math:`y`
        ends up being classified as belonging to class :math:`c`
    :param p_c: classifier predictions, where the entry `c` is the estimate of :math:`P(C=c)`. Shape `(n_classes,)`
    :param method: adjustment method to be used:
        'inversion': matrix inversion method based on the matrix equality :math:`P(C)=P(C|Y)P(Y)`,
        which tries to invert `P(C|Y)` matrix.
        'invariant-ratio': invariant ratio estimator of `Vaz et al. <https://jmlr.org/papers/v20/18-456.html>`_,
        which replaces the last equation with the normalization condition.
    :param solver: the method to use for solving the system of linear equations. Valid options are:
        'exact-raise': tries to solve the system using matrix inversion. Raises an error if the matrix has
        rank strictly less than `n_classes`.
        'exact-cc': if the matrix is not of full rank, returns `p_c` as the estimates, which corresponds
        to no adjustment (i.e., the classify and count method. See :class:`quapy.method.aggregative.CC`)
        'exact': deprecated, defaults to 'exact-cc'
        'minimize': minimizes a loss, so the solution always exists
    :return: `np.ndarray` of shape `(n_classes,)` with the solution of the (possibly modified) linear system.
        The values are not clipped nor normalized by this function.
    :raises ValueError: if `method` or `solver` is not one of the valid options
    :raises numpy.linalg.LinAlgError: if `solver` is 'exact-raise' and the system cannot be solved
    """
    if solver == "exact":
        warnings.warn("The 'exact' solver is deprecated. Use 'exact-raise' or 'exact-cc'", DeprecationWarning, stacklevel=2)
        solver = "exact-cc"

    # Validate the solver up front. Checking it only after a failed inversion would let a
    # misspelled solver silently behave as an exact solve whenever the matrix is invertible.
    if solver not in ("minimize", "exact-raise", "exact-cc"):
        raise ValueError(f"Solver {solver} not known.")

    # np.array copies its input, so the in-place edits below never touch the caller's arrays
    A = np.array(p_c_y, dtype=float)
    B = np.array(p_c, dtype=float)

    if method == "inversion":
        pass  # We leave A and B unchanged
    elif method == "invariant-ratio":
        # Replace the last equation with the normalization condition sum_y P(Y=y) = 1
        A[-1, :] = 1.0
        B[-1] = 1.0
    else:
        raise ValueError(f"Flavour {method} not known.")

    if solver == "minimize":
        def loss(prev):
            return np.linalg.norm(A @ prev - B)

        return optim_minimize(loss, n_classes=A.shape[0])

    # Solvers based on matrix inversion, so we use try/except block
    try:
        return np.linalg.solve(A, B)
    except np.linalg.LinAlgError:
        # The matrix is not invertible.
        # Depending on the solver, we either raise an error
        # or return the classifier predictions without adjustment
        if solver == "exact-raise":
            raise
        # 'exact-cc': fall back to the original (unmodified) predictions, as an array
        return np.asarray(p_c, dtype=float)
|
||||||
|
|
|
@ -435,27 +435,13 @@ class ACC(AggregativeCrispQuantifier):
|
||||||
:return: an adjusted `np.ndarray` of shape `(n_classes,)` with the corrected class prevalence estimates
|
:return: an adjusted `np.ndarray` of shape `(n_classes,)` with the corrected class prevalence estimates
|
||||||
"""
|
"""
|
||||||
|
|
||||||
A = PteCondEstim
|
estimate = F.solve_adjustment(
|
||||||
B = prevs_estim
|
p_c_y=PteCondEstim,
|
||||||
|
p_c=prevs_estim,
|
||||||
if solver == 'exact':
|
solver=solver,
|
||||||
# attempts an exact solution of the linear system (may fail)
|
method='inversion',
|
||||||
|
)
|
||||||
try:
|
return F.clip_prevalence(estimate, method="clip")
|
||||||
adjusted_prevs = np.linalg.solve(A, B)
|
|
||||||
adjusted_prevs = F.clip_prevalence(adjusted_prevs, method="clip")
|
|
||||||
except np.linalg.LinAlgError:
|
|
||||||
adjusted_prevs = prevs_estim # no way to adjust them!
|
|
||||||
|
|
||||||
return adjusted_prevs
|
|
||||||
|
|
||||||
elif solver == 'minimize':
|
|
||||||
# poses the problem as an optimization one, and tries to minimize the norm of the differences
|
|
||||||
|
|
||||||
def loss(prev):
|
|
||||||
return np.linalg.norm(A @ prev - B)
|
|
||||||
|
|
||||||
return F.optim_minimize(loss, n_classes=A.shape[0])
|
|
||||||
|
|
||||||
|
|
||||||
class PCC(AggregativeSoftQuantifier):
|
class PCC(AggregativeSoftQuantifier):
|
||||||
|
|
Loading…
Reference in New Issue