new plots

This commit is contained in:
Alejandro Moreo Fernandez 2025-12-10 13:11:49 +01:00
parent 1eebfbc709
commit 5c2554861c
2 changed files with 222 additions and 0 deletions

View File

@ -0,0 +1,126 @@
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
def plot_kde_background(ax, data, cmap="Blues", alpha=0.35, gridsize=200):
    """
    Shade `ax` with filled contours of a Gaussian KDE fitted to `data`.

    ax: object exposing `contourf` (e.g. a matplotlib Axes)
    data: array Nx2; column 0 goes on the x axis, column 1 on the y axis
    cmap, alpha: forwarded unchanged to `ax.contourf`
    gridsize: number of evaluation points along each axis
    """
    xs, ys = data[:, 0], data[:, 1]

    # Fit the estimator (gaussian_kde expects one row per dimension)
    density = gaussian_kde(data.T)

    # Evaluation grid: bounding box of the data padded by 1 on every side
    grid_x = np.linspace(xs.min() - 1, xs.max() + 1, gridsize)
    grid_y = np.linspace(ys.min() - 1, ys.max() + 1, gridsize)
    X, Y = np.meshgrid(grid_x, grid_y)

    # Evaluate on the flattened grid, then fold back to the grid shape
    flat_points = np.vstack([X.ravel(), Y.ravel()])
    Z = density(flat_points).reshape(X.shape)

    # Draw background density
    ax.contourf(X, Y, Z, levels=30, cmap=cmap, alpha=alpha)
# ======================================================
# Three 2-D Gaussian sources sharing a single covariance
# ======================================================
# Centres, in order: mu1 (negative), mu2 (positive), mu3 (positive)
mu1, mu2, mu3 = (np.array(center) for center in ([0, 0], [3, 0], [0, 3]))

# Shared covariance: unit variances, 0.2 covariance between the two features
Sigma = np.array([[1, 0.2], [0.2, 1]])
def sample_gaussian(mu, Sigma, n, rng=None):
    """
    Draw `n` points from the multivariate normal N(mu, Sigma).

    mu: mean vector (length d)
    Sigma: covariance matrix (d x d)
    n: number of samples to draw
    rng: optional object exposing `multivariate_normal`, e.g.
        `np.random.default_rng(seed)`, for reproducible figures. When
        omitted, NumPy's global random state is used.

    Returns whatever `multivariate_normal(mu, Sigma, n)` returns: an
    n x d array of samples.
    """
    sampler = np.random if rng is None else rng
    return sampler.multivariate_normal(mu, Sigma, n)
# ======================================================
# Sample the point clouds shown in the four panels
# ======================================================
density = 20  # multiplier applied to every per-group sample count
sources = (mu1, mu2, mu3)

# ---------- Scenario 1: Baseline ----------
# 100*density points from each of the three sources
G1_1, G2_1, G3_1 = [sample_gaussian(mu, Sigma, 100 * density) for mu in sources]

# ---------- Scenario 2: Prior Probability Shift ----------
# same sources, but the group sizes become 300/50/50 (times density)
G1_2, G2_2, G3_2 = [
    sample_gaussian(mu, Sigma, size * density)
    for mu, size in zip(sources, (300, 50, 50))
]

# ---------- Scenario 3: Covariate Shift ----------
# same group sizes as the baseline, but the third source is moved
# by +1.5 along the first feature
mu3_shift = mu3 + np.array([1.5, 0])
G1_3, G2_3, G3_3 = [
    sample_gaussian(mu, Sigma, 100 * density)
    for mu in (mu1, mu2, mu3_shift)
]

# ---------- Scenario 4: Concept Shift ----------
# reuses the Scenario 1 arrays; the third group is only coloured
# differently (as negative) at plotting time
G1_4, G2_4, G3_4 = G1_1, G2_1, G3_1
# ======================================================
# Plotting function for each scenario
# ======================================================
def plot_scenario(ax, G1, G2, G3, title, G3_negative=False):
    """
    Scatter the three point groups of one scenario on `ax`.

    ax: axes-like object to draw on
    G1: array Nx2, drawn in red and labelled as the negative class
    G2: array Nx2, drawn in blue and labelled as the positive class
    G3: array Nx2, drawn without a legend label; red when `G3_negative`
        is true, blue otherwise
    title: panel title
    G3_negative: whether the third group is coloured as negative
    """
    # NOTE: KDE backgrounds are intentionally not drawn; only the scatter
    # clouds are shown.
    marker_style = dict(s=12, alpha=0.1)

    # Raw strings: "\o" is not a valid Python escape sequence, so the
    # mathtext commands must not go through escape processing.
    ax.scatter(G1[:, 0], G1[:, 1], color='red',
               label=r'Negative ($\ominus$)', **marker_style)
    ax.scatter(G2[:, 0], G2[:, 1], color='blue',
               label=r'Positive ($\oplus$)', **marker_style)

    # The third group carries no label, so it adds no extra legend entry
    g3_color = 'red' if G3_negative else 'blue'
    ax.scatter(G3[:, 0], G3[:, 1], color=g3_color, **marker_style)

    ax.set_title(title)
    ax.set_xlabel("$x_1$")
    ax.set_ylabel("$x_2$")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(alpha=0.3)
# ======================================================
# Generate 2×2 grid of subplots
# ======================================================
fig, axes = plt.subplots(2, 2, figsize=(9, 9))

# One entry per panel:
# (grid position, point groups, panel title, draw G3 as negative?)
panels = [
    ((0, 0), (G1_1, G2_1, G3_1), "Training data", False),
    ((0, 1), (G1_2, G2_2, G3_2), "Prior Probability Shift", False),
    ((1, 0), (G1_3, G2_3, G3_3), "Covariate Shift", False),
    ((1, 1), (G1_4, G2_4, G3_4), "Concept Shift", True),
]
for position, groups, panel_title, g3_is_negative in panels:
    plot_scenario(axes[position], *groups, panel_title,
                  G3_negative=g3_is_negative)

# One global legend, built from the labelled artists of the first panel
handles, labels = axes[0, 0].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', ncol=3, fontsize=12)

# Restrict the layout to the lower 95% of the figure height
plt.tight_layout(rect=[0, 0, 1, 0.95])

# Write the figure to disk (the interactive plt.show() is not used)
plt.savefig('dataset_shift_types.pdf')

View File

@ -0,0 +1,96 @@
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from quapy.functional import uniform_prevalence_sampling
# Vertices of a regular tetrahedron with unit edge length
v1 = np.array([0, 0, 0])
v2 = np.array([1, 0, 0])
v3 = np.array([0.5, np.sqrt(3)/2, 0])
v4 = np.array([0.5, np.sqrt(3)/6, np.sqrt(6)/3])
vertices = np.array([v1, v2, v3, v4])


# Function to map (p1,p2,p3,p4) to 3D coordinates
def prob_to_xyz(p):
    """
    Map 4-class probability vectors to 3D points in the tetrahedron.

    Each point is the combination p1*v1 + p2*v2 + p3*v3 + p4*v4 of the
    tetrahedron vertices, weighted by the entries of the vector.

    p: a single vector of length 4, or an array of shape (N, 4) holding
        one vector per row
    Returns an array of shape (3,) for a single vector, or (N, 3) for a
    batch.
    """
    # The matrix product applies the same weighted sum to every row,
    # so single vectors and batches share one code path.
    return np.asarray(p, dtype=float) @ vertices
# --- Example: 3000 4-class prevalence vectors from uniform_prevalence_sampling
rand_probs = uniform_prevalence_sampling(n_classes=4, size=3000)
# Convert every sampled vector to its 3D position inside the tetrahedron
points_xyz = np.array(list(map(prob_to_xyz, rand_probs)))
# --- Plotting
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')

corners = (v1, v2, v3, v4)

# Draw tetrahedron faces: every triple of corners, as translucent polygons
faces = [
    [corners[i], corners[j], corners[k]]
    for i, j, k in ((0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3))
]
poly = Poly3DCollection(faces, alpha=0.15, edgecolor='k', facecolor=None)
ax.add_collection3d(poly)

# Outline the six edges (every pair of corners) with thin black lines
edges = [
    [corners[i], corners[j]]
    for i, j in ((0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3))
]
for start, end in edges:
    ax.plot(
        [start[0], end[0]],
        [start[1], end[1]],
        [start[2], end[2]],
        color='black',
        linewidth=1,
    )
# Draw vertices
# ax.scatter(vertices[:,0], vertices[:,1], vertices[:,2], s=60, color='red')
# Labels
# Each label is pushed `offset` away from its vertex along the ray that runs
# from the tetrahedron's centroid through that vertex. Normalising the vertex
# itself would divide by zero for the vertex at the origin and give that
# label a NaN position.
offset = 0.08
labels = ["$y_1$", "$y_2$", "$y_3$", "$y_4$"]
centroid = vertices.mean(axis=0)
for i, v in enumerate(vertices):
    outward = v - centroid
    direction = outward / np.linalg.norm(outward)
    label_pos = v + offset * direction
    ax.text(label_pos[0], label_pos[1], label_pos[2], labels[i], fontsize=14, color='black')
# Plot random points
ax.scatter(points_xyz[:,0], points_xyz[:,1], points_xyz[:,2], s=10, c='blue', alpha=0.2)

# Axes formatting: keep only the title; hide ticks, tick labels, panes,
# axis names and the grid.
ax.set_title("4-Class Probability Simplex (Tetrahedron)")
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])
for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    axis.set_ticks_position('none')  # keep tick marks from being drawn
    axis.pane.set_visible(False)
# Axis names are left empty. (They used to be set to "X"/"Y"/"Z" first and
# then cleared again, which had no effect on the final figure.)
ax.set_xlabel('')
ax.set_ylabel('')
ax.set_zlabel('')
ax.grid(False)
plt.tight_layout()
plt.show()