adding plots to prior test
This commit is contained in:
parent
300b8e6423
commit
93c33fe237
|
|
@ -36,6 +36,9 @@ def methods():
|
|||
|
||||
yield 'BayesianACC', ACC(LR()), acc_hyper, lambda hyper: BayesianCC(LR(), mcmc_seed=0, prior='uniform')
|
||||
yield f'BaKDE-Ait', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', temperature=None, prior='uniform', **hyper)
|
||||
yield f'BaKDE-Ait-T2', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0,
|
||||
engine='numpyro', temperature=2.,
|
||||
prior='uniform', **hyper)
|
||||
|
||||
|
||||
def run_test(test, alpha_test, alpha_train, concentration, prior_type, bay_quant, train_prev, dataset_name, method_name, results):
|
||||
|
|
@ -117,7 +120,7 @@ def concat_reports(reports):
|
|||
return df
|
||||
|
||||
|
||||
def plot_results(df):
|
||||
def error_vs_concentration_plot(df, err='ae'):
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
|
@ -131,7 +134,7 @@ def plot_results(df):
|
|||
sns.lineplot(
|
||||
data=sub,
|
||||
x='concentration',
|
||||
y='ae',
|
||||
y=err,
|
||||
hue='method_name',
|
||||
marker='o',
|
||||
errorbar='se', # o 'sd'
|
||||
|
|
@ -141,7 +144,103 @@ def plot_results(df):
|
|||
ax.set_xscale('log')
|
||||
ax.set_title(f'Prior: {prior}')
|
||||
ax.set_xlabel('Concentration')
|
||||
ax.set_ylabel('MAE')
|
||||
ax.set_ylabel('M'+err.upper())
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
def coverage_vs_concentration_plot(df):
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True)
|
||||
|
||||
for ax, prior in zip(axes, ['informative', 'wrong']):
|
||||
sub = df[df['prior-type'] == prior]
|
||||
|
||||
sns.lineplot(
|
||||
data=sub,
|
||||
x='concentration',
|
||||
y='coverage',
|
||||
hue='method_name',
|
||||
marker='o',
|
||||
errorbar='se',
|
||||
ax=ax
|
||||
)
|
||||
|
||||
ax.set_xscale('log')
|
||||
ax.set_ylim(0, 1.05)
|
||||
ax.set_title(f'Prior: {prior}')
|
||||
ax.set_xlabel('Concentration')
|
||||
ax.set_ylabel('Coverage')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
def amplitude_vs_concentration_plot(df):
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True)
|
||||
|
||||
for ax, prior in zip(axes, ['informative', 'wrong']):
|
||||
sub = df[df['prior-type'] == prior]
|
||||
|
||||
sns.lineplot(
|
||||
data=sub,
|
||||
x='concentration',
|
||||
y='amplitude',
|
||||
hue='method_name',
|
||||
marker='o',
|
||||
errorbar='se',
|
||||
ax=ax
|
||||
)
|
||||
|
||||
ax.set_xscale('log')
|
||||
ax.set_title(f'Prior: {prior}')
|
||||
ax.set_xlabel('Concentration')
|
||||
ax.set_ylabel('Amplitude')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
def coverage_vs_amplitude_plot(df):
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
agg = (
|
||||
df
|
||||
.groupby(['prior-type', 'method_name', 'concentration'])
|
||||
.agg(
|
||||
coverage=('coverage', 'mean'),
|
||||
amplitude=('amplitude', 'mean')
|
||||
)
|
||||
.reset_index()
|
||||
)
|
||||
|
||||
fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True)
|
||||
|
||||
for ax, prior in zip(axes, ['informative', 'wrong']):
|
||||
sub = agg[agg['prior-type'] == prior]
|
||||
|
||||
sns.scatterplot(
|
||||
data=sub,
|
||||
x='amplitude',
|
||||
y='coverage',
|
||||
hue='method_name',
|
||||
style='concentration',
|
||||
s=80,
|
||||
ax=ax
|
||||
)
|
||||
|
||||
ax.set_ylim(0, 1.05)
|
||||
ax.set_title(f'Prior: {prior}')
|
||||
ax.set_xlabel('Amplitude')
|
||||
ax.set_ylabel('Coverage')
|
||||
ax.axhline(0.95, linestyle='--', color='gray', alpha=0.7)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
|
@ -153,7 +252,7 @@ if __name__ == '__main__':
|
|||
print(f'selected datasets={selected}')
|
||||
qp.environ['SAMPLE_SIZE'] = multiclass['sample_size']
|
||||
reports = []
|
||||
for data_name in selected[:2]:
|
||||
for data_name in selected:
|
||||
data = multiclass['fetch_fn'](data_name)
|
||||
for method_name, surrogate_quant, hyper_params, bay_constructor in methods():
|
||||
result_path = experiment_path(result_dir, data_name, method_name)
|
||||
|
|
@ -168,6 +267,11 @@ if __name__ == '__main__':
|
|||
# concat all reports as a dataframe
|
||||
df = concat_reports(reports)
|
||||
|
||||
plot_results(df)
|
||||
for data_name in selected:
|
||||
print(data_name)
|
||||
df_ = df[df['dataset']==data_name]
|
||||
error_vs_concentration_plot(df_)
|
||||
coverage_vs_concentration_plot(df_)
|
||||
amplitude_vs_concentration_plot(df_)
|
||||
coverage_vs_amplitude_plot(df_)
|
||||
|
||||
print('ONLY TWO DATASETS')
|
||||
Loading…
Reference in New Issue