adding plots to prior test

This commit is contained in:
Alejandro Moreo Fernandez 2026-01-13 17:29:40 +01:00
parent 300b8e6423
commit 93c33fe237
1 changed files with 110 additions and 6 deletions

View File

@ -36,6 +36,9 @@ def methods():
yield 'BayesianACC', ACC(LR()), acc_hyper, lambda hyper: BayesianCC(LR(), mcmc_seed=0, prior='uniform')
yield f'BaKDE-Ait', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', temperature=None, prior='uniform', **hyper)
yield f'BaKDE-Ait-T2', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0,
engine='numpyro', temperature=2.,
prior='uniform', **hyper)
def run_test(test, alpha_test, alpha_train, concentration, prior_type, bay_quant, train_prev, dataset_name, method_name, results):
@ -117,7 +120,7 @@ def concat_reports(reports):
return df
def plot_results(df):
def error_vs_concentration_plot(df, err='ae'):
import seaborn as sns
import matplotlib.pyplot as plt
@ -131,7 +134,7 @@ def plot_results(df):
sns.lineplot(
data=sub,
x='concentration',
y='ae',
y=err,
hue='method_name',
marker='o',
errorbar='se', # o 'sd'
@ -141,7 +144,103 @@ def plot_results(df):
ax.set_xscale('log')
ax.set_title(f'Prior: {prior}')
ax.set_xlabel('Concentration')
ax.set_ylabel('MAE')
ax.set_ylabel('M'+err.upper())
plt.tight_layout()
plt.show()
def coverage_vs_concentration_plot(df):
import seaborn as sns
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True)
for ax, prior in zip(axes, ['informative', 'wrong']):
sub = df[df['prior-type'] == prior]
sns.lineplot(
data=sub,
x='concentration',
y='coverage',
hue='method_name',
marker='o',
errorbar='se',
ax=ax
)
ax.set_xscale('log')
ax.set_ylim(0, 1.05)
ax.set_title(f'Prior: {prior}')
ax.set_xlabel('Concentration')
ax.set_ylabel('Coverage')
plt.tight_layout()
plt.show()
def amplitude_vs_concentration_plot(df):
import seaborn as sns
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True)
for ax, prior in zip(axes, ['informative', 'wrong']):
sub = df[df['prior-type'] == prior]
sns.lineplot(
data=sub,
x='concentration',
y='amplitude',
hue='method_name',
marker='o',
errorbar='se',
ax=ax
)
ax.set_xscale('log')
ax.set_title(f'Prior: {prior}')
ax.set_xlabel('Concentration')
ax.set_ylabel('Amplitude')
plt.tight_layout()
plt.show()
def coverage_vs_amplitude_plot(df):
import seaborn as sns
import matplotlib.pyplot as plt
agg = (
df
.groupby(['prior-type', 'method_name', 'concentration'])
.agg(
coverage=('coverage', 'mean'),
amplitude=('amplitude', 'mean')
)
.reset_index()
)
fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True)
for ax, prior in zip(axes, ['informative', 'wrong']):
sub = agg[agg['prior-type'] == prior]
sns.scatterplot(
data=sub,
x='amplitude',
y='coverage',
hue='method_name',
style='concentration',
s=80,
ax=ax
)
ax.set_ylim(0, 1.05)
ax.set_title(f'Prior: {prior}')
ax.set_xlabel('Amplitude')
ax.set_ylabel('Coverage')
ax.axhline(0.95, linestyle='--', color='gray', alpha=0.7)
plt.tight_layout()
plt.show()
@ -153,7 +252,7 @@ if __name__ == '__main__':
print(f'selected datasets={selected}')
qp.environ['SAMPLE_SIZE'] = multiclass['sample_size']
reports = []
for data_name in selected[:2]:
for data_name in selected:
data = multiclass['fetch_fn'](data_name)
for method_name, surrogate_quant, hyper_params, bay_constructor in methods():
result_path = experiment_path(result_dir, data_name, method_name)
@ -168,6 +267,11 @@ if __name__ == '__main__':
# concat all reports as a dataframe
df = concat_reports(reports)
plot_results(df)
for data_name in selected:
print(data_name)
df_ = df[df['dataset']==data_name]
error_vs_concentration_plot(df_)
coverage_vs_concentration_plot(df_)
amplitude_vs_concentration_plot(df_)
coverage_vs_amplitude_plot(df_)
print('ONLY TWO DATASETS')