225 lines
7.2 KiB
Python
225 lines
7.2 KiB
Python
import os
|
|
import pickle
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib.colors import ListedColormap
|
|
from scipy.stats import gaussian_kde
|
|
|
|
from method.confidence import (ConfidenceIntervals as CI,
|
|
ConfidenceEllipseSimplex as CE,
|
|
ConfidenceEllipseCLR as CLR,
|
|
ConfidenceEllipseILR as ILR)
|
|
|
|
|
|
|
|
def get_region_colormap(name="blue", alpha=0.40):
    """Build a two-entry colormap for painting a binary region mask.

    Entry 0 is fully transparent white (for points outside the region) and
    entry 1 is the requested palette colour with opacity ``alpha``.

    Args:
        name: palette name, case-insensitive; one of "blue", "orange", "violet".
        alpha: opacity stored in the colour entry of the colormap.

    Returns:
        A ``ListedColormap`` with exactly two RGBA entries.

    Raises:
        ValueError: if ``name`` (lower-cased) is not a known palette.
    """
    palette = {
        "blue": (76/255, 114/255, 176/255),
        "orange": (221/255, 132/255, 82/255),
        "violet": (129/255, 114/255, 178/255),
    }

    name = name.lower()
    if name not in palette:
        raise ValueError(f"Unknown palette name: {name}")

    red, green, blue = palette[name]
    outside = (1, 1, 1, 0)                # 0: transparent white
    inside = (red, green, blue, alpha)    # 1: palette colour
    return ListedColormap([outside, inside])
|
|
|
|
|
|
def plot_prev_points(prevs=None, true_prev=None, point_estim=None, train_prev=None, show_mean=True, show_legend=True,
                     region=None,
                     region_resolution=1000,
                     confine_region_in_simplex=False,
                     color='blue',
                     save_path=None):
    """Scatter 3-class prevalence vectors on the 2D simplex, optionally over shaded regions.

    Points are given in barycentric coordinates (p1, p2, p3) and projected onto
    an equilateral triangle whose vertices are labelled "Y=1", "Y=2" and "Y=3".

    Args:
        prevs: array-like of shape (n, 3) with the sampled prevalence vectors;
            when None no sample cloud (and no sample mean) is drawn.
        true_prev: optional length-3 prevalence drawn with label 'true-prev'.
        point_estim: optional length-3 prevalence drawn with label 'KDEy-estim'.
        train_prev: optional length-3 prevalence drawn with label 'train-prev'.
        show_mean: whether to draw the mean of ``prevs``.
        show_legend: whether to draw the legend to the right of the axes.
        region: either a single callable or an iterable of ``(name, callable)``
            pairs. Each callable is invoked with one barycentric point at a
            time and its result is converted with ``float()``; every region is
            painted as its own translucent layer.
        region_resolution: number of grid points per axis used to rasterise
            the regions (the callable is evaluated once per grid point).
        confine_region_in_simplex: when True, the region callables are only
            evaluated on grid points whose barycentric coordinates are all
            non-negative; elsewhere the mask is 0.
        color: palette name forwarded to ``get_region_colormap``.
        save_path: when None the figure is shown interactively; otherwise the
            parent directory is created if needed, the figure is written to
            this path and then closed.

    Note:
        This function updates ``plt.rcParams`` (font sizes) globally.
    """
    plt.rcParams.update({
        'font.size': 10,        # base size for all text
        'axes.titlesize': 12,   # axes title
        'axes.labelsize': 10,   # axis labels
        'xtick.labelsize': 8,   # tick labels
        'ytick.labelsize': 8,
        'legend.fontsize': 9,   # legend
    })

    def cartesian(p):
        """Project barycentric point(s) of shape (..., 3) to cartesian (x, y) arrays."""
        p = np.asarray(p)  # accept plain lists/tuples as well as arrays
        p = p.reshape(-1, p.shape[-1])
        x = p[:, 1] + p[:, 2] * 0.5
        y = p[:, 2] * np.sqrt(3) / 2
        return x, y

    def barycentric_from_xy(x, y):
        """Given cartesian (x, y) in the simplex plane, return barycentric (p1, p2, p3)."""
        p3 = 2 * y / np.sqrt(3)
        p2 = x - 0.5 * p3
        p1 = 1 - p2 - p3
        return np.stack([p1, p2, p3], axis=-1)

    # simplex vertices in cartesian coordinates
    v1 = np.array([0, 0])
    v2 = np.array([1, 0])
    v3 = np.array([0.5, np.sqrt(3)/2])

    fig, ax = plt.subplots(figsize=(6, 6))

    if region is not None:
        # normalise to a list of (name, fn) pairs
        region_list = [("region", region)] if callable(region) else region

        # rectangular mesh slightly larger than the triangle
        x_min, x_max = -0.2, 1.2
        y_min, y_max = -0.2, np.sqrt(3) / 2 + 0.2
        xs = np.linspace(x_min, x_max, region_resolution)
        ys = np.linspace(y_min, y_max, region_resolution)
        grid_x, grid_y = np.meshgrid(xs, ys)

        pts_bary = barycentric_from_xy(grid_x, grid_y)

        # grid points on which the region functions are evaluated
        if confine_region_in_simplex:
            in_simplex = np.all(pts_bary >= 0, axis=-1)
        else:
            in_simplex = np.ones(grid_x.shape, dtype=bool)

        # Both are identical for every region, so compute them once.
        valid_pts = pts_bary[in_simplex]
        region_cmap = get_region_colormap(color)  # 0 -> transparent white, 1 -> palette colour

        for _rname, rfun in region_list:
            mask = np.zeros(in_simplex.shape, dtype=float)
            mask[in_simplex] = [float(rfun(p)) for p in valid_pts]

            ax.pcolormesh(
                xs, ys, mask,
                shading='auto',
                cmap=region_cmap,
                # Pin the 0/1 -> colour mapping: with automatic scaling a
                # constant mask (all 0 or all 1) would collapse to entry 0.
                vmin=0, vmax=1,
                alpha=0.3,
            )

    if prevs is not None:
        prevs = np.asarray(prevs)
        ax.scatter(*cartesian(prevs), s=15, alpha=0.5, edgecolors='none', label='samples', color='black', linewidth=0.5)
        if show_mean:
            ax.scatter(*cartesian(prevs.mean(axis=0)), s=10, alpha=1, label='sample-mean', edgecolors='black')
    # The true prevalence must be gated on true_prev itself (the original
    # checked train_prev here, which crashed or hid the marker).
    if true_prev is not None:
        ax.scatter(*cartesian(true_prev), s=10, alpha=1, label='true-prev', edgecolors='black')
    if point_estim is not None:
        ax.scatter(*cartesian(point_estim), s=10, alpha=1, label='KDEy-estim', edgecolors='black')
    if train_prev is not None:
        ax.scatter(*cartesian(train_prev), s=10, alpha=1, label='train-prev', edgecolors='black')

    # triangle edges
    triangle = np.array([v1, v2, v3, v1])
    ax.plot(triangle[:, 0], triangle[:, 1], color='black')

    # vertex labels
    ax.text(-0.05, -0.05, "Y=1", ha='right', va='top')
    ax.text(1.05, -0.05, "Y=2", ha='left', va='top')
    ax.text(0.5, np.sqrt(3)/2 + 0.05, "Y=3", ha='center', va='bottom')

    ax.set_aspect('equal')
    ax.axis('off')
    if show_legend:
        ax.legend(
            loc='center left',
            bbox_to_anchor=(1.05, 0.5),
        )
    fig.tight_layout()

    if save_path is None:
        plt.show()
    else:
        os.makedirs(Path(save_path).parent, exist_ok=True)
        fig.savefig(save_path)
        # Release the figure; otherwise repeated calls accumulate open figures.
        plt.close(fig)
|
|
|
|
|
|
def plot_prev_points_matplot(points):
    """Show a KDE heat-map of 3-class prevalence points on the 2D simplex.

    The barycentric rows of ``points`` (columns p1, p2, p3) are projected to
    the plane, a Gaussian KDE (``bw_method=0.25``) is fitted on the projected
    points and evaluated on a 200x200 grid, and only the part of the density
    lying inside the triangle is displayed. The figure is shown interactively.

    Args:
        points: array indexable as ``points[:, 1]`` / ``points[:, 2]``, one
            prevalence vector per row.
    """
    sqrt3 = np.sqrt(3)

    # 2D projection of the barycentric coordinates
    px = points[:, 1] + points[:, 2] * 0.5
    py = points[:, 2] * sqrt3 / 2

    # density estimate over the projected samples
    density = gaussian_kde(np.vstack([px, py]), bw_method=0.25)

    # evaluation grid covering the triangle's bounding box
    xmin, xmax = 0, 1
    ymin, ymax = 0, sqrt3 / 2
    gx, gy = np.mgrid[xmin:xmax:200j, ymin:ymax:200j]
    grid_points = np.vstack([gx.ravel(), gy.ravel()])
    dens = np.reshape(density(grid_points).T, gx.shape)

    # hide every grid cell that falls outside the triangle
    inside = (gy >= 0) & (gy <= sqrt3 * np.minimum(gx, 1 - gx))
    dens_masked = np.ma.array(dens, mask=np.logical_not(inside))

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(
        np.rot90(dens_masked),
        cmap=plt.cm.viridis,
        extent=[xmin, xmax, ymin, ymax],
        alpha=0.8,
    )

    # triangle outline (closed path through the three vertices)
    outline = np.array([[0, 0], [1, 0], [0.5, sqrt3 / 2], [0, 0]])
    ax.plot(outline[:, 0], outline[:, 1], color='black', lw=2)

    # the samples themselves, drawn faintly on top
    ax.scatter(px, py, s=5, c='white', alpha=0.3)

    # vertex labels
    vertex_labels = [
        (-0.05, -0.05, "A (1,0,0)", 'right', 'top'),
        (1.05, -0.05, "B (0,1,0)", 'left', 'top'),
        (0.5, sqrt3 / 2 + 0.05, "C (0,0,1)", 'center', 'bottom'),
    ]
    for lx, ly, text, halign, valign in vertex_labels:
        ax.text(lx, ly, text, ha=halign, va=valign)

    ax.set_aspect('equal')
    ax.axis('off')
    plt.show()
|
|
|
|
if __name__ == '__main__':
    # Demo: draw Dirichlet samples and save one simplex plot per
    # confidence-region method, each with three nested confidence levels.
    np.random.seed(1)  # fixed seed so the sample cloud is reproducible

    n_samples = 1000
    concentration = [3,5,10]
    # alternative concentration to try: [10,1,1]
    prevs = np.random.dirichlet(concentration, size=n_samples)

    def regions():
        """Lazily yield (method name, [(confidence label, coverage fn), ...])."""
        levels = [0.99, 0.95, 0.90]
        builders = (
            ('CI', lambda c: CI(prevs, confidence_level=c)),
            ('CI-b', lambda c: CI(prevs, confidence_level=c, bonferroni_correction=True)),
            ('CE', lambda c: CE(prevs, confidence_level=c)),
            ('CLR', lambda c: CLR(prevs, confidence_level=c)),
            ('ILR', lambda c: ILR(prevs, confidence_level=c)),
        )
        # Regions of one method are only built when that method is requested,
        # i.e. right before its plot is produced.
        for method, build in builders:
            yield method, [(f'{int(c*100)}%', build(c).coverage) for c in levels]

    resolution = 1000
    alpha_str = ','.join(str(a) for a in concentration)
    for crname, cr in regions():
        out_file = f'./plots/simplex_{crname}_alpha{alpha_str}_res{resolution}.png'
        plot_prev_points(
            prevs,
            show_mean=True,
            show_legend=False,
            region=cr,
            region_resolution=resolution,
            color='blue',
            save_path=out_file,
        )
|
|
|