ilr added

This commit is contained in:
Alejandro Moreo Fernandez 2025-12-06 13:20:11 +01:00
parent 59ef17c86c
commit d52fc40d2b
7 changed files with 190 additions and 159 deletions

View File

@ -1,7 +1,8 @@
from sklearn.base import BaseEstimator
import numpy as np
from quapy.method._kdey import KDEBase
from quapy.method.confidence import WithConfidenceABC, ConfidenceRegionABC, CLRtransformation, ILRtransformation
from quapy.method.confidence import WithConfidenceABC, ConfidenceRegionABC
from functional import CLRtransformation, ILRtransformation
from quapy.method.aggregative import AggregativeSoftQuantifier
from tqdm import tqdm
import quapy.functional as F

View File

@ -36,6 +36,14 @@ class KDEyCLR(KDEyML):
)
class KDEyILR(KDEyML):
def __init__(self, classifier: BaseEstimator=None, fit_classifier=True, val_split=5, bandwidth=1., random_state=None):
super().__init__(
classifier=classifier, fit_classifier=fit_classifier, val_split=val_split, bandwidth=bandwidth,
random_state=random_state, kernel='ilr'
)
def methods():
"""
Returns a tuple (name, quantifier, hyperparams, bayesian/bootstrap_constructor), where:
@ -66,6 +74,7 @@ def methods():
yield 'BayKDEy*CLR', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, explore='clr', step_size=.15, **hyper), multiclass_method
# yield 'BayKDEy*CLR2', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, explore='clr', step_size=.05, **hyper), multiclass_method
yield 'BayKDEy*ILR', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, explore='ilr', step_size=.15, **hyper), only_multiclass
yield 'BayKDEy*ILR2', KDEyILR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='ilr', mcmc_seed=0, explore='ilr', step_size=.1, **hyper), only_multiclass
def model_selection(train: LabelledCollection, point_quantifier: AggregativeQuantifier, grid: dict):

View File

@ -7,13 +7,37 @@ import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from scipy.stats import gaussian_kde
from method.confidence import ConfidenceIntervals, ConfidenceEllipseSimplex, ConfidenceEllipseCLR, ConfidenceEllipseILR
from method.confidence import (ConfidenceIntervals as CI,
ConfidenceEllipseSimplex as CE,
ConfidenceEllipseCLR as CLR,
ConfidenceEllipseILR as ILR)
def get_region_colormap(name="blue", alpha=0.40):
name = name.lower()
if name == "blue":
base = (76/255, 114/255, 176/255)
elif name == "orange":
base = (221/255, 132/255, 82/255)
elif name == "violet":
base = (129/255, 114/255, 178/255)
else:
raise ValueError(f"Unknown palette name: {name}")
cmap = ListedColormap([
(1, 1, 1, 0), # 0: transparent white
(base[0], base[1], base[2], alpha) # 1: color
])
return cmap
def plot_prev_points(prevs=None, true_prev=None, point_estim=None, train_prev=None, show_mean=True, show_legend=True,
region=None,
region_resolution=1000,
confine_region_in_simplex=False,
color='blue',
save_path=None):
plt.rcParams.update({
@ -49,13 +73,22 @@ def plot_prev_points(prevs=None, true_prev=None, point_estim=None, train_prev=No
# Plot
fig, ax = plt.subplots(figsize=(6, 6))
if region is not None:
if callable(region):
region_list = [("region", region)]
else:
region_list = region # list of (name, fn)
if region is not None:
# rectangular mesh
xs = np.linspace(0, 1, region_resolution)
ys = np.linspace(0, np.sqrt(3)/2, region_resolution)
x_min, x_max = -0.2, 1.2
y_min, y_max = -0.2, np.sqrt(3) / 2 + 0.2
xs = np.linspace(x_min, x_max, region_resolution)
ys = np.linspace(y_min, y_max, region_resolution)
grid_x, grid_y = np.meshgrid(xs, ys)
# 2 barycentric
# barycentric
pts_bary = barycentric_from_xy(grid_x, grid_y)
# mask within simplex
@ -64,29 +97,23 @@ def plot_prev_points(prevs=None, true_prev=None, point_estim=None, train_prev=No
else:
in_simplex = np.full(shape=(region_resolution, region_resolution), fill_value=True, dtype=bool)
# evaluate the region only at the valid points
# --- Colormap: 0 → white, 1 → semi-transparent red ---
# iterate over all the regions
for (rname, rfun) in region_list:
mask = np.zeros_like(in_simplex, dtype=float)
valid_pts = pts_bary[in_simplex]
mask_vals = np.array([float(region(p)) for p in valid_pts])
mask_vals = np.array([float(rfun(p)) for p in valid_pts])
mask[in_simplex] = mask_vals
# paint the background
white_and_color = ListedColormap([
(1, 1, 1, 1), # color for value 0
(0.7, .0, .0, .5) # color for value 1
])
ax.pcolormesh(
xs, ys, mask,
shading='auto',
cmap=white_and_color,
alpha=0.5
cmap=get_region_colormap(color),
alpha=0.3,
)
ax.scatter(*cartesian(prevs), s=15, alpha=0.5, edgecolors='none', label='samples')
ax.scatter(*cartesian(prevs), s=15, alpha=0.5, edgecolors='none', label='samples', color='black', linewidth=0.5)
if show_mean:
ax.scatter(*cartesian(prevs.mean(axis=0)), s=10, alpha=1, label='sample-mean', edgecolors='black')
if train_prev is not None:
@ -96,8 +123,6 @@ def plot_prev_points(prevs=None, true_prev=None, point_estim=None, train_prev=No
if train_prev is not None:
ax.scatter(*cartesian(train_prev), s=10, alpha=1, label='train-prev', edgecolors='black')
# edges
triangle = np.array([v1, v2, v3, v1])
ax.plot(triangle[:, 0], triangle[:, 1], color='black')
@ -179,17 +204,21 @@ if __name__ == '__main__':
n = 1000
alpha = [3,5,10]
# alpha = [10,1,1]
prevs = np.random.dirichlet(alpha, size=n)
def regions():
yield 'CI', ConfidenceIntervals(prevs)
yield 'CE', ConfidenceEllipseSimplex(prevs)
yield 'CLR', ConfidenceEllipseCLR(prevs)
yield 'ILR', ConfidenceEllipseILR(prevs)
confs = [0.99, 0.95, 0.90]
yield 'CI', [(f'{int(c*100)}%', CI(prevs, confidence_level=c).coverage) for c in confs]
yield 'CI-b', [(f'{int(c * 100)}%', CI(prevs, confidence_level=c, bonferroni_correction=True).coverage) for c in confs]
yield 'CE', [(f'{int(c*100)}%', CE(prevs, confidence_level=c).coverage) for c in confs]
yield 'CLR', [(f'{int(c*100)}%', CLR(prevs, confidence_level=c).coverage) for c in confs]
yield 'ILR', [(f'{int(c*100)}%', ILR(prevs, confidence_level=c).coverage) for c in confs]
resolution = 100
resolution = 1000
alpha_str = ','.join([f'{str(i)}' for i in alpha])
for crname, cr in regions():
plot_prev_points(prevs, show_mean=True, show_legend=False, region=cr.coverage, region_resolution=resolution,
plot_prev_points(prevs, show_mean=True, show_legend=False, region=cr, region_resolution=resolution,
color='blue',
save_path=f'./plots/simplex_{crname}_alpha{alpha_str}_res{resolution}.png')
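The helpers cartesian and barycentric_from_xy used above are not included in this diff; a minimal sketch of the ternary-plot convention they appear to follow (vertices v1=(0,0), v2=(1,0), v3=(1/2, sqrt(3)/2), consistent with the y-range sqrt(3)/2 used for the mesh) could look like this. This is a hypothetical reconstruction, not the repository's code.
import numpy as np
# assumed simplex vertices (hypothetical, inferred from the plotting code above)
v1, v2, v3 = np.array([0., 0.]), np.array([1., 0.]), np.array([0.5, np.sqrt(3) / 2])
def cartesian(p):
    # p: (..., 3) barycentric coordinates summing to 1 -> (x, y) plot coordinates
    p = np.asarray(p)
    x = p[..., 1] + 0.5 * p[..., 2]
    y = (np.sqrt(3) / 2) * p[..., 2]
    return x, y
def barycentric_from_xy(x, y):
    # inverse map: (x, y) -> (p1, p2, p3); points outside the triangle get
    # negative components, which is what the in-simplex mask filters out
    p3 = 2 * y / np.sqrt(3)
    p2 = x - 0.5 * p3
    p1 = 1 - p2 - p3
    return np.stack([p1, p2, p3], axis=-1)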

View File

@ -1,10 +1,15 @@
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from functools import lru_cache
from typing import Literal, Union, Callable
from numpy.typing import ArrayLike
import scipy
import numpy as np
from scipy.special import softmax
import quapy as qp
# ------------------------------------------------------------------------------------------
@ -649,3 +654,96 @@ def solve_adjustment(
raise ValueError(f'unknown {solver=}')
# ------------------------------------------------------------------------------------------
# Transformations from Compositional analysis
# ------------------------------------------------------------------------------------------
class CompositionalTransformation(ABC):
"""
Abstract class for transformations of compositional data.
Essentially, callable functions that also provide an "inverse" function.
"""
@abstractmethod
def __call__(self, X): ...
@abstractmethod
def inverse(self, Z): ...
EPSILON=1e-6
class CLRtransformation(CompositionalTransformation):
"""
Centered log-ratio (CLR), from compositional analysis
"""
def __call__(self, X):
"""
Applies the CLR transformation to X, thus mapping the instances, which are contained in :math:`\\mathcal{R}^{n}` but
actually lie on an :math:`(n-1)`-dimensional simplex, onto an unrestricted space in :math:`\\mathcal{R}^{n}`.
Prevalence vectors are smoothed with the constant EPSILON to avoid taking the logarithm of zero.
:param X: np.ndarray of shape (n_instances, n_dimensions) to be transformed
:return: np.ndarray of shape (n_instances, n_dimensions), the CLR-transformed points
"""
X = np.asarray(X)
X = qp.error.smooth(X, self.EPSILON)
G = np.exp(np.mean(np.log(X), axis=-1, keepdims=True)) # geometric mean
return np.log(X / G)
def inverse(self, Z):
"""
Inverse function. However, clr.inverse(clr(X)) does not exactly coincide with X due to smoothing.
:param Z: np.ndarray of shape (n_instances, n_dimensions) to be back-transformed
:return: np.ndarray of shape (n_instances, n_dimensions), the points mapped back onto the probability simplex
"""
return softmax(Z, axis=-1)
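As a quick numerical illustration of the two methods above (a self-contained sketch in plain NumPy that omits the smoothing step):
import numpy as np
from scipy.special import softmax
X = np.array([[0.2, 0.3, 0.5]])                            # a point in the probability simplex
G = np.exp(np.mean(np.log(X), axis=-1, keepdims=True))     # geometric mean per row
Z = np.log(X / G)                                          # CLR coordinates: rows sum to 0
print(Z.sum(axis=-1))                                      # ~[0.]
print(softmax(Z, axis=-1))                                 # recovers [[0.2, 0.3, 0.5]]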
class ILRtransformation(CompositionalTransformation):
"""
Isometric log-ratio (ILR), from compositional analysis
"""
def __call__(self, X):
X = np.asarray(X)
X = qp.error.smooth(X, self.EPSILON)
k = X.shape[-1]
V = self.get_V(k) # (k-1, k)
logp = np.log(X)
return logp @ V.T
def inverse(self, Z):
Z = np.asarray(Z)
# get dimension
k_minus_1 = Z.shape[-1]
k = k_minus_1 + 1
V = self.get_V(k) # (k-1, k)
logp = Z @ V
p = np.exp(logp)
p = p / np.sum(p, axis=-1, keepdims=True)
return p
@lru_cache(maxsize=None)
def get_V(self, k):
def helmert_matrix(k):
"""
Returns the (k x k) Helmert matrix.
"""
H = np.zeros((k, k))
for i in range(1, k):
H[i, :i] = 1
H[i, i] = -(i)
H[i] = H[i] / np.sqrt(i * (i + 1))
# row 0 stays zeros; will be discarded
return H
def ilr_basis(k):
"""
Constructs an orthonormal ILR basis using the Helmert submatrix.
Output shape: (k-1, k)
"""
H = helmert_matrix(k)
V = H[1:, :] # remove first row of zeros
return V
return ilr_basis(k)
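The key property that makes inverse() work is that the Helmert-derived basis V is orthonormal with zero-sum rows, so Z @ V lands back in CLR coordinates. A minimal standalone check mirroring the construction above (not part of the commit):
import numpy as np
def ilr_basis(k):
    # rows i=1..k-1 of the Helmert matrix, normalized: orthonormal and zero-sum
    H = np.zeros((k, k))
    for i in range(1, k):
        H[i, :i] = 1
        H[i, i] = -i
        H[i] /= np.sqrt(i * (i + 1))
    return H[1:, :]
V = ilr_basis(4)                              # shape (3, 4)
print(np.allclose(V @ V.T, np.eye(3)))        # True: orthonormal rows
print(np.allclose(V.sum(axis=1), 0))          # True: rows sum to zero
p = np.array([0.1, 0.2, 0.3, 0.4])
z = np.log(p) @ V.T                           # ILR coordinates (no smoothing needed here)
q = np.exp(z @ V)
q /= q.sum()
print(np.allclose(q, p))                      # True: round-trip recovers p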

View File

@ -9,27 +9,13 @@ import quapy.functional as F
from sklearn.metrics.pairwise import rbf_kernel
# class KDE(KernelDensity):
#
# KERNELS = ['gaussian', 'aitchison']
#
# def __init__(self, bandwidth, kernel):
# assert kernel in KDE.KERNELS, f'unknown {kernel=}'
# self.bandwidth = bandwidth
# self.kernel = kernel
#
# def
class KDEBase:
"""
Common ancestor for KDE-based methods. Implements some common routines.
"""
BANDWIDTH_METHOD = ['scott', 'silverman']
KERNELS = ['gaussian', 'aitchison']
KERNELS = ['gaussian', 'aitchison', 'ilr']
@classmethod
@ -59,18 +45,6 @@ class KDEBase:
assert kernel in KDEBase.KERNELS, f'unknown {kernel=}'
return kernel
@classmethod
def clr_transform(cls, P, eps=1e-7):
"""
Centered-Log Ratio (CLR) transform.
P: array (n_samples, n_classes), every row is a point in the probability simplex
eps: smoothing, to avoid log(0)
"""
X_safe = np.clip(P, eps, None)
X_safe /= X_safe.sum(axis=1, keepdims=True) # renormalize
gm = np.exp(np.mean(np.log(X_safe), axis=1, keepdims=True))
return np.log(X_safe / gm)
def get_kde_function(self, X, bandwidth, kernel):
"""
Wraps the KDE function from scikit-learn.
@ -81,7 +55,9 @@ class KDEBase:
:return: a scikit-learn's KernelDensity object
"""
if kernel == 'aitchison':
X = KDEBase.clr_transform(X)
X = self.clr_transform(X)
elif kernel == 'ilr':
X = self.ilr_transform(X)
return KernelDensity(bandwidth=bandwidth).fit(X)
@ -96,7 +72,9 @@ class KDEBase:
:return: np.ndarray with the densities
"""
if kernel == 'aitchison':
X = KDEBase.clr_transform(X)
X = self.clr_transform(X)
elif kernel == 'ilr':
X = self.ilr_transform(X)
return np.exp(kde.score_samples(X))
@ -117,12 +95,19 @@ class KDEBase:
if selX.size==0:
selX = [F.uniform_prevalence(len(classes))]
# if kernel == 'aitchison':
# this is already done within get_kde_function
# selX = KDEBase.clr_transform(selX)
class_cond_X.append(selX)
return [self.get_kde_function(X_cond_yi, bandwidth, kernel) for X_cond_yi in class_cond_X]
def clr_transform(self, X):
if not hasattr(self, 'clr'):
self.clr = F.CLRtransformation()
return self.clr(X)
def ilr_transform(self, X):
if not hasattr(self, 'ilr'):
self.ilr = F.ILRtransformation()
return self.ilr(X)
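The effect of the 'aitchison' and 'ilr' kernels above is simply to map the class-conditional posteriors out of the simplex before handing them to scikit-learn's Gaussian KDE. A minimal standalone sketch of the idea (illustrative only, not QuaPy's exact code):
import numpy as np
from sklearn.neighbors import KernelDensity
def clr(P, eps=1e-6):
    # smooth, renormalize, then centered log-ratio
    P = np.clip(P, eps, None)
    P = P / P.sum(axis=1, keepdims=True)
    G = np.exp(np.mean(np.log(P), axis=1, keepdims=True))
    return np.log(P / G)
# posteriors of some validation instances (rows lie on the probability simplex)
posteriors = np.random.dirichlet([3, 5, 10], size=200)
kde = KernelDensity(bandwidth=0.2).fit(clr(posteriors))   # Gaussian KDE in CLR space
test = np.random.dirichlet([3, 5, 10], size=5)
densities = np.exp(kde.score_samples(clr(test)))          # densities of new points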
class KDEyML(AggregativeSoftQuantifier, KDEBase):
"""

View File

@ -673,7 +673,7 @@ class PACC(AggregativeSoftQuantifier):
class EMQ(AggregativeSoftQuantifier):
"""
`Expectation Maximization for Quantification <https://ieeexplore.ieee.org/abstract/document/6789744>`_ (EMQ),
aka `Saerens-Latinne-Decaestecker` (SLD) algorithm.
aka `Saerens-Latinne-Decaestecker` (SLD) algorithm, or `Maximum Likelihood Label Shift` (MLLS).
EMQ consists of using the well-known `Expectation Maximization algorithm` to iteratively update the posterior
probabilities generated by a probabilistic classifier and the class prevalence estimates obtained via
maximum-likelihood estimation, in a mutually recursive way, until convergence.
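For intuition, the mutually recursive update described above can be sketched in a few lines of NumPy (a simplified illustration of the SLD iteration, not QuaPy's implementation):
import numpy as np
def emq_sketch(posteriors, train_prev, max_iter=1000, tol=1e-6):
    # posteriors: (n_instances, n_classes) from the classifier; train_prev: (n_classes,) np.ndarray
    prev = train_prev.copy()
    for _ in range(max_iter):
        # E-step: reweight posteriors by the ratio of current to training prevalence
        P = posteriors * (prev / train_prev)
        P /= P.sum(axis=1, keepdims=True)
        # M-step: maximum-likelihood prevalence = mean of the corrected posteriors
        new_prev = P.mean(axis=0)
        if np.abs(new_prev - prev).max() < tol:
            break
        prev = new_prev
    return prev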

View File

@ -5,13 +5,14 @@ from sklearn.metrics import confusion_matrix
import quapy as qp
import quapy.functional as F
from functional import CompositionalTransformation, CLRtransformation, ILRtransformation
from quapy.method import _bayesian
from quapy.data import LabelledCollection
from quapy.method.aggregative import AggregativeQuantifier, AggregativeCrispQuantifier, AggregativeSoftQuantifier, BinaryAggregativeQuantifier
from scipy.stats import chi2
from sklearn.utils import resample
from abc import ABC, abstractmethod
from scipy.special import softmax, factorial
from scipy.special import factorial
import copy
from functools import lru_cache
from tqdm import tqdm
@ -218,98 +219,6 @@ def within_ellipse_prop(values, mean, prec_matrix, chi2_critical):
return float(np.mean(within_ellipse))
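The body of within_ellipse_prop is mostly elided by this hunk; the computation its signature suggests is the standard squared-Mahalanobis-distance test against a chi-squared critical value, roughly as follows (a hypothetical reconstruction, not the file's exact code):
import numpy as np
def within_ellipse_prop_sketch(values, mean, prec_matrix, chi2_critical):
    # squared Mahalanobis distance of each row of `values` to `mean`
    diff = np.asarray(values) - mean
    d2 = np.einsum('ij,jk,ik->i', diff, prec_matrix, diff)
    within_ellipse = d2 <= chi2_critical
    return float(np.mean(within_ellipse))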
class CompositionalTransformation(ABC):
"""
Abstract class of transformations from compositional data.
Basically, callable functions with an "inverse" function.
"""
@abstractmethod
def __call__(self, X): ...
@abstractmethod
def inverse(self, Z): ...
EPSILON=1e-6
class CLRtransformation(CompositionalTransformation):
"""
Centered log-ratio (CLR), from compositional analysis
"""
def __call__(self, X):
"""
Applies the CLR function to X thus mapping the instances, which are contained in `\\mathcal{R}^{n}` but
actually lie on a `\\mathcal{R}^{n-1}` simplex, onto an unrestricted space in :math:`\\mathcal{R}^{n}`
:param X: np.ndarray of (n_instances, n_dimensions) to be transformed
:param epsilon: small float for prevalence smoothing
:return: np.ndarray of (n_instances, n_dimensions), the CLR-transformed points
"""
X = np.asarray(X)
X = qp.error.smooth(X, self.EPSILON)
G = np.exp(np.mean(np.log(X), axis=-1, keepdims=True)) # geometric mean
return np.log(X / G)
def inverse(self, Z):
"""
Inverse function. However, clr.inverse(clr(X)) does not exactly coincide with X due to smoothing.
:param Z: np.ndarray of (n_instances, n_dimensions) to be transformed
:return: np.ndarray of (n_instances, n_dimensions), the CLR-transformed points
"""
return softmax(Z, axis=-1)
class ILRtransformation(CompositionalTransformation):
"""
Isometric log-ratio (ILR), from compositional analysis
"""
def __call__(self, X):
X = np.asarray(X)
X = qp.error.smooth(X, self.EPSILON)
k = X.shape[-1]
V = self.get_V(k) # (k-1, k)
logp = np.log(X)
return logp @ V.T
def inverse(self, Z):
Z = np.asarray(Z)
# get dimension
k_minus_1 = Z.shape[-1]
k = k_minus_1 + 1
V = self.get_V(k) # (k-1, k)
logp = Z @ V
p = np.exp(logp)
p = p / np.sum(p, axis=-1, keepdims=True)
return p
@lru_cache(maxsize=None)
def get_V(self, k):
def helmert_matrix(k):
"""
Returns the (k x k) Helmert matrix.
"""
H = np.zeros((k, k))
for i in range(1, k):
H[i, :i] = 1
H[i, i] = -(i)
H[i] = H[i] / np.sqrt(i * (i + 1))
# row 0 stays zeros; will be discarded
return H
def ilr_basis(k):
"""
Constructs an orthonormal ILR basis using the Helmert submatrix.
Output shape: (k-1, k)
"""
H = helmert_matrix(k)
V = H[1:, :] # remove first row of zeros
return V
return ilr_basis(k)
class ConfidenceEllipseSimplex(ConfidenceRegionABC):
"""
Instantiates a Confidence Ellipse in the probability simplex.