From 93c33fe237e82f8c922bbcec53d8881bc6022cc7 Mon Sep 17 00:00:00 2001 From: Alejandro Moreo Date: Tue, 13 Jan 2026 17:29:40 +0100 Subject: [PATCH] adding plots to prior test --- BayesianKDEy/prior_effect.py | 116 +++++++++++++++++++++++++++++++++-- 1 file changed, 110 insertions(+), 6 deletions(-) diff --git a/BayesianKDEy/prior_effect.py b/BayesianKDEy/prior_effect.py index 4fa1f19..bec9632 100644 --- a/BayesianKDEy/prior_effect.py +++ b/BayesianKDEy/prior_effect.py @@ -36,6 +36,9 @@ def methods(): yield 'BayesianACC', ACC(LR()), acc_hyper, lambda hyper: BayesianCC(LR(), mcmc_seed=0, prior='uniform') yield f'BaKDE-Ait', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', temperature=None, prior='uniform', **hyper) + yield f'BaKDE-Ait-T2', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, + engine='numpyro', temperature=2., + prior='uniform', **hyper) def run_test(test, alpha_test, alpha_train, concentration, prior_type, bay_quant, train_prev, dataset_name, method_name, results): @@ -117,7 +120,7 @@ def concat_reports(reports): return df -def plot_results(df): +def error_vs_concentration_plot(df, err='ae'): import seaborn as sns import matplotlib.pyplot as plt @@ -131,7 +134,7 @@ def plot_results(df): sns.lineplot( data=sub, x='concentration', - y='ae', + y=err, hue='method_name', marker='o', errorbar='se', # o 'sd' @@ -141,7 +144,103 @@ def plot_results(df): ax.set_xscale('log') ax.set_title(f'Prior: {prior}') ax.set_xlabel('Concentration') - ax.set_ylabel('MAE') + ax.set_ylabel('M'+err.upper()) + + plt.tight_layout() + plt.show() + + +def coverage_vs_concentration_plot(df): + import seaborn as sns + import matplotlib.pyplot as plt + + fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True) + + for ax, prior in zip(axes, ['informative', 'wrong']): + sub = df[df['prior-type'] == prior] + + sns.lineplot( + data=sub, + x='concentration', + y='coverage', + hue='method_name', + marker='o', + errorbar='se', + ax=ax + ) + + ax.set_xscale('log') + ax.set_ylim(0, 1.05) + ax.set_title(f'Prior: {prior}') + ax.set_xlabel('Concentration') + ax.set_ylabel('Coverage') + + plt.tight_layout() + plt.show() + + +def amplitude_vs_concentration_plot(df): + import seaborn as sns + import matplotlib.pyplot as plt + + fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True) + + for ax, prior in zip(axes, ['informative', 'wrong']): + sub = df[df['prior-type'] == prior] + + sns.lineplot( + data=sub, + x='concentration', + y='amplitude', + hue='method_name', + marker='o', + errorbar='se', + ax=ax + ) + + ax.set_xscale('log') + ax.set_title(f'Prior: {prior}') + ax.set_xlabel('Concentration') + ax.set_ylabel('Amplitude') + + plt.tight_layout() + plt.show() + + +def coverage_vs_amplitude_plot(df): + import seaborn as sns + import matplotlib.pyplot as plt + + agg = ( + df + .groupby(['prior-type', 'method_name', 'concentration']) + .agg( + coverage=('coverage', 'mean'), + amplitude=('amplitude', 'mean') + ) + .reset_index() + ) + + fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True) + + for ax, prior in zip(axes, ['informative', 'wrong']): + sub = agg[agg['prior-type'] == prior] + + sns.scatterplot( + data=sub, + x='amplitude', + y='coverage', + hue='method_name', + style='concentration', + s=80, + ax=ax + ) + + ax.set_ylim(0, 1.05) + ax.set_title(f'Prior: {prior}') + ax.set_xlabel('Amplitude') + ax.set_ylabel('Coverage') + ax.axhline(0.95, linestyle='--', color='gray', alpha=0.7) plt.tight_layout() plt.show() @@ -153,7 +252,7 @@ if __name__ == '__main__': print(f'selected datasets={selected}') qp.environ['SAMPLE_SIZE'] = multiclass['sample_size'] reports = [] - for data_name in selected[:2]: + for data_name in selected: data = multiclass['fetch_fn'](data_name) for method_name, surrogate_quant, hyper_params, bay_constructor in methods(): result_path = experiment_path(result_dir, data_name, method_name) @@ -168,6 +267,11 @@ if __name__ == '__main__': # concat all reports as a dataframe df = concat_reports(reports) - plot_results(df) + for data_name in selected: + print(data_name) + df_ = df[df['dataset']==data_name] + error_vs_concentration_plot(df_) + coverage_vs_concentration_plot(df_) + amplitude_vs_concentration_plot(df_) + coverage_vs_amplitude_plot(df_) - print('ONLY TWO DATASETS') \ No newline at end of file