QuaPy/BayesianKDEy/bandwidth_and_dimensionalit...

79 lines
2.1 KiB
Python

from sklearn.neighbors import KernelDensity
import quapy.functional as F
import numpy as np
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KernelDensity
import quapy.functional as F
# Space in which the KDE is fitted and evaluated:
#   False -> directly on the simplex (raw prevalence vectors)
#   True  -> after mapping every point through `clr` (F.CLRtransformation)
aitchison = False
clr = F.CLRtransformation()

# Numbers of classes (simplex dimensions) to sweep over.
# A denser sweep is possible with: dims = list(range(5, 100, 5))
dims = [10, 28, 100]

# One entry per value in `dims`: density at the simplex center / near-vertex,
# for the base bandwidth and for the dimension-scaled bandwidth.
center_densities, vertex_densities = [], []
center_densities_scaled, vertex_densities_scaled = [], []
h0 = 0.4  # base KDE bandwidth; identical for every dimension, so set once


def _kde_densities(bandwidth, train, query, transform=None):
    """Fit a KernelDensity on ``train`` and return its densities at ``query``.

    :param bandwidth: KDE bandwidth passed to ``KernelDensity``.
    :param train: 2-D array of training points, one row per point.
    :param query: 2-D array of points at which the density is evaluated.
    :param transform: optional callable (e.g. ``clr``) applied to both
        ``train`` and ``query`` so that fitting and scoring happen in the
        same space; ``None`` means no transform.
    :return: 1-D array with one density value per row of ``query``.
    """
    if transform is not None:
        train = transform(train)
        query = transform(query)
    kde = KernelDensity(bandwidth=bandwidth).fit(train)
    # score_samples returns log-densities; exponentiate to get densities.
    return np.exp(kde.score_samples(query))


# The transform choice does not depend on n, so decide it once.
transform = clr if aitchison else None

for n in dims:
    simplex_center = F.uniform_prevalence(n)
    # Point near (not on) a vertex: 0.9 mass on the first class, the remaining
    # 0.1 split evenly among the other n - 1 classes. Presumably kept strictly
    # positive so the log-ratio (CLR) transform stays finite -- TODO confirm.
    simplex_vertex = np.asarray([.9] + [.1 / (n - 1)] * (n - 1), dtype=float)

    # KDE trained on a single point (the center), evaluated at center and vertex.
    train = simplex_center[None, :]
    query = np.vstack([simplex_center, simplex_vertex])

    # Base bandwidth, constant across dimensions.
    density = _kde_densities(h0, train, query, transform)
    center_densities.append(density[0])
    vertex_densities.append(density[1])

    # Bandwidth rescaled with the dimension: h1 = h0 * sqrt(n / 2).
    h1 = h0 * np.sqrt(n / 2)
    density = _kde_densities(h1, train, query, transform)
    center_densities_scaled.append(density[0])
    vertex_densities_scaled.append(density[1])
# Plot the KDE density at the simplex center and near-vertex against the number
# of classes, for the base bandwidth h0 and for the scaled bandwidth
# h0 * sqrt(n / 2).
plt.figure(figsize=(6*4, 4*4))
curves = [
    (center_densities, 'o', 'Center of simplex'),
    (vertex_densities, 's', 'Vertex of simplex'),
    (center_densities_scaled, 'o', 'Center of simplex (scaled)'),
    (vertex_densities_scaled, 's', 'Vertex of simplex (scaled)'),
]
for values, marker, label in curves:
    plt.plot(dims, values, marker=marker, label=label)
plt.xlabel('Number of classes (simplex dimension)')
plt.ylabel('Kernel density')
plt.yscale('log')  # crucial to see anything meaningful: curves differ by orders of magnitude
# When `aitchison` is set, points go through `clr = F.CLRtransformation()`,
# so the space is CLR (not ILR). The scaled curves use h0 * sqrt(n / 2), not h0.
space = 'CLR-space' if aitchison else 'Simplex'
plt.title(f'KDE density vs dimension (bandwidth = {h0}; scaled = {h0}*sqrt(n/2)) in {space}')
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()