import os
import pickle
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from scipy.stats import gaussian_kde

from method.confidence import ConfidenceIntervals, ConfidenceEllipseSimplex, ConfidenceEllipseCLR, ConfidenceEllipseILR


def plot_prev_points(prevs=None, true_prev=None, point_estim=None, train_prev=None, show_mean=True, show_legend=True,
                     region=None,
                     region_resolution=1000,
                     confine_region_in_simplex=False,
                     save_path=None):
    """
    Plot 3-class prevalence vectors on the 2-simplex (drawn as an equilateral triangle).

    :param prevs: array-like of prevalence vectors, expected shape (n, 3); plotted as 'samples'.
        If None, no samples (and no sample mean) are drawn.
    :param true_prev: optional prevalence vector plotted as 'true-prev'.
    :param point_estim: optional prevalence vector plotted as 'KDEy-estim'.
    :param train_prev: optional prevalence vector plotted as 'train-prev'.
    :param show_mean: if True (and `prevs` is given), also plot the mean of `prevs` as 'sample-mean'.
    :param show_legend: if True, draw a legend to the right of the triangle.
    :param region: optional callable receiving one barycentric point (length-3 array) and returning a
        value convertible to float; 0 is painted white and 1 translucent red.
    :param region_resolution: number of grid points per axis used to evaluate `region`
        (so `region` is called up to region_resolution**2 times).
    :param confine_region_in_simplex: if True, `region` is only evaluated at grid points whose
        barycentric coordinates are all non-negative; other points are painted as value 0.
    :param save_path: if None the figure is shown with `plt.show()`; otherwise it is written to this
        path (parent directories are created) and the figure is closed.
    """
    plt.rcParams.update({
        'font.size': 10,          # base size for all text
        'axes.titlesize': 12,     # axes title
        'axes.labelsize': 10,     # axis labels
        'xtick.labelsize': 8,     # tick labels
        'ytick.labelsize': 8,
        'legend.fontsize': 9,     # legend
    })

    def cartesian(p):
        """Project barycentric point(s) onto 2-D triangle coordinates; returns (x, y) arrays."""
        p = np.asarray(p)
        dim = p.shape[-1]
        p = p.reshape(-1, dim)
        x = p[:, 1] + p[:, 2] * 0.5
        y = p[:, 2] * np.sqrt(3) / 2
        return x, y

    def barycentric_from_xy(x, y):
        """
        Given cartesian (x, y) in the simplex plane, return barycentric coordinates (p1, p2, p3)
        stacked along the last axis.
        """
        p3 = 2 * y / np.sqrt(3)
        p2 = x - 0.5 * p3
        p1 = 1 - p2 - p3
        return np.stack([p1, p2, p3], axis=-1)

    # simplex vertices in 2-D
    v1 = np.array([0, 0])
    v2 = np.array([1, 0])
    v3 = np.array([0.5, np.sqrt(3) / 2])

    fig, ax = plt.subplots(figsize=(6, 6))

    if region is not None:
        # rectangular mesh covering the triangle's bounding box
        xs = np.linspace(0, 1, region_resolution)
        ys = np.linspace(0, np.sqrt(3) / 2, region_resolution)
        grid_x, grid_y = np.meshgrid(xs, ys)

        # convert every grid node to barycentric coordinates
        pts_bary = barycentric_from_xy(grid_x, grid_y)

        # optionally restrict evaluation to nodes inside the simplex
        if confine_region_in_simplex:
            in_simplex = np.all(pts_bary >= 0, axis=-1)
        else:
            in_simplex = np.full(shape=(region_resolution, region_resolution), fill_value=True, dtype=bool)

        # evaluate the region only at the selected nodes; all other nodes stay 0
        mask = np.zeros_like(in_simplex, dtype=float)
        valid_pts = pts_bary[in_simplex]
        mask_vals = np.array([float(region(p)) for p in valid_pts], dtype=float)
        mask[in_simplex] = mask_vals

        # paint the background
        white_and_color = ListedColormap([
            (1, 1, 1, 1),        # color for value 0
            (0.7, .0, .0, .5),   # color for value 1
        ])
        # Pin the colour scale to [0, 1]: without it a constant mask (e.g. a region covering every
        # evaluated node) autoscales to vmin == vmax, normalizes to 0 and is drawn entirely white.
        ax.pcolormesh(
            xs, ys, mask,
            shading='auto',
            cmap=white_and_color,
            vmin=0, vmax=1,
            alpha=0.5,
        )

    if prevs is not None:
        prevs = np.asarray(prevs)
        ax.scatter(*cartesian(prevs), s=15, alpha=0.5, edgecolors='none', label='samples')
        if show_mean:
            ax.scatter(*cartesian(prevs.mean(axis=0)), s=10, alpha=1, label='sample-mean', edgecolors='black')
    if true_prev is not None:
        ax.scatter(*cartesian(true_prev), s=10, alpha=1, label='true-prev', edgecolors='black')
    if point_estim is not None:
        ax.scatter(*cartesian(point_estim), s=10, alpha=1, label='KDEy-estim', edgecolors='black')
    if train_prev is not None:
        ax.scatter(*cartesian(train_prev), s=10, alpha=1, label='train-prev', edgecolors='black')

    # triangle edges
    triangle = np.array([v1, v2, v3, v1])
    ax.plot(triangle[:, 0], triangle[:, 1], color='black')

    # vertex labels
    ax.text(-0.05, -0.05, "Y=1", ha='right', va='top')
    ax.text(1.05, -0.05, "Y=2", ha='left', va='top')
    ax.text(0.5, np.sqrt(3) / 2 + 0.05, "Y=3", ha='center', va='bottom')

    ax.set_aspect('equal')
    ax.axis('off')
    if show_legend:
        ax.legend(
            loc='center left',
            bbox_to_anchor=(1.05, 0.5),
        )
    fig.tight_layout()
    if save_path is None:
        plt.show()
    else:
        os.makedirs(Path(save_path).parent, exist_ok=True)
        fig.savefig(save_path)
        # Release the figure: pyplot keeps every figure alive until it is closed, and callers
        # (such as the __main__ loop below) produce one figure per call.
        plt.close(fig)


def plot_prev_points_matplot(points):
    """
    Show a Gaussian-KDE density heatmap of 3-class prevalence points over the 2-simplex.

    :param points: array of prevalence vectors; columns 1 and 2 are used for the 2-D projection
        (expected shape (n, 3)).
    """
    # project to 2-D
    v1 = np.array([0, 0])
    v2 = np.array([1, 0])
    v3 = np.array([0.5, np.sqrt(3) / 2])
    x = points[:, 1] + points[:, 2] * 0.5
    y = points[:, 2] * np.sqrt(3) / 2

    # kde
    xy = np.vstack([x, y])
    kde = gaussian_kde(xy, bw_method=0.25)

    xmin, xmax = 0, 1
    ymin, ymax = 0, np.sqrt(3) / 2

    # grid (xx varies along axis 0, yy along axis 1)
    xx, yy = np.mgrid[xmin:xmax:200j, ymin:ymax:200j]
    positions = np.vstack([xx.ravel(), yy.ravel()])
    zz = np.reshape(kde(positions).T, xx.shape)

    # mask out grid points that fall outside the simplex
    def in_triangle(x, y):
        return (y >= 0) & (y <= np.sqrt(3) * np.minimum(x, 1 - x))

    mask = in_triangle(xx, yy)
    zz_masked = np.ma.array(zz, mask=~mask)

    # plot
    fig, ax = plt.subplots(figsize=(6, 6))
    # rot90 turns the (x, y)-indexed grid into the (row=top-to-bottom y, col=x) layout imshow expects
    ax.imshow(
        np.rot90(zz_masked),
        cmap=plt.cm.viridis,
        extent=[xmin, xmax, ymin, ymax],
        alpha=0.8,
    )

    # triangle edges
    triangle = np.array([v1, v2, v3, v1])
    ax.plot(triangle[:, 0], triangle[:, 1], color='black', lw=2)

    # points (optional)
    ax.scatter(x, y, s=5, c='white', alpha=0.3)

    # labels
    ax.text(-0.05, -0.05, "A (1,0,0)", ha='right', va='top')
    ax.text(1.05, -0.05, "B (0,1,0)", ha='left', va='top')
    ax.text(0.5, np.sqrt(3) / 2 + 0.05, "C (0,0,1)", ha='center', va='bottom')

    ax.set_aspect('equal')
    ax.axis('off')
    plt.show()


if __name__ == '__main__':
    np.random.seed(1)
    n = 1000
    alpha = [3, 5, 10]
    prevs = np.random.dirichlet(alpha, size=n)

    def regions():
        yield 'CI', ConfidenceIntervals(prevs)
        yield 'CE', ConfidenceEllipseSimplex(prevs)
        yield 'CLR', ConfidenceEllipseCLR(prevs)
        yield 'ILR', ConfidenceEllipseILR(prevs)

    resolution = 100
    alpha_str = ','.join(map(str, alpha))
    for crname, cr in regions():
        plot_prev_points(prevs, show_mean=True, show_legend=False, region=cr.coverage, region_resolution=resolution,
                         save_path=f'./plots/simplex_{crname}_alpha{alpha_str}_res{resolution}.png')