adding plots to prior test
This commit is contained in:
parent
300b8e6423
commit
93c33fe237
|
|
@ -36,6 +36,9 @@ def methods():
|
||||||
|
|
||||||
yield 'BayesianACC', ACC(LR()), acc_hyper, lambda hyper: BayesianCC(LR(), mcmc_seed=0, prior='uniform')
|
yield 'BayesianACC', ACC(LR()), acc_hyper, lambda hyper: BayesianCC(LR(), mcmc_seed=0, prior='uniform')
|
||||||
yield f'BaKDE-Ait', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', temperature=None, prior='uniform', **hyper)
|
yield f'BaKDE-Ait', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, engine='numpyro', temperature=None, prior='uniform', **hyper)
|
||||||
|
yield f'BaKDE-Ait-T2', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0,
|
||||||
|
engine='numpyro', temperature=2.,
|
||||||
|
prior='uniform', **hyper)
|
||||||
|
|
||||||
|
|
||||||
def run_test(test, alpha_test, alpha_train, concentration, prior_type, bay_quant, train_prev, dataset_name, method_name, results):
|
def run_test(test, alpha_test, alpha_train, concentration, prior_type, bay_quant, train_prev, dataset_name, method_name, results):
|
||||||
|
|
@ -117,7 +120,7 @@ def concat_reports(reports):
|
||||||
return df
|
return df
|
||||||
|
|
||||||
|
|
||||||
def plot_results(df):
|
def error_vs_concentration_plot(df, err='ae'):
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
@ -131,7 +134,7 @@ def plot_results(df):
|
||||||
sns.lineplot(
|
sns.lineplot(
|
||||||
data=sub,
|
data=sub,
|
||||||
x='concentration',
|
x='concentration',
|
||||||
y='ae',
|
y=err,
|
||||||
hue='method_name',
|
hue='method_name',
|
||||||
marker='o',
|
marker='o',
|
||||||
errorbar='se', # o 'sd'
|
errorbar='se', # o 'sd'
|
||||||
|
|
@ -141,7 +144,103 @@ def plot_results(df):
|
||||||
ax.set_xscale('log')
|
ax.set_xscale('log')
|
||||||
ax.set_title(f'Prior: {prior}')
|
ax.set_title(f'Prior: {prior}')
|
||||||
ax.set_xlabel('Concentration')
|
ax.set_xlabel('Concentration')
|
||||||
ax.set_ylabel('MAE')
|
ax.set_ylabel('M'+err.upper())
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def coverage_vs_concentration_plot(df):
|
||||||
|
import seaborn as sns
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True)
|
||||||
|
|
||||||
|
for ax, prior in zip(axes, ['informative', 'wrong']):
|
||||||
|
sub = df[df['prior-type'] == prior]
|
||||||
|
|
||||||
|
sns.lineplot(
|
||||||
|
data=sub,
|
||||||
|
x='concentration',
|
||||||
|
y='coverage',
|
||||||
|
hue='method_name',
|
||||||
|
marker='o',
|
||||||
|
errorbar='se',
|
||||||
|
ax=ax
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set_xscale('log')
|
||||||
|
ax.set_ylim(0, 1.05)
|
||||||
|
ax.set_title(f'Prior: {prior}')
|
||||||
|
ax.set_xlabel('Concentration')
|
||||||
|
ax.set_ylabel('Coverage')
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def amplitude_vs_concentration_plot(df):
|
||||||
|
import seaborn as sns
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True)
|
||||||
|
|
||||||
|
for ax, prior in zip(axes, ['informative', 'wrong']):
|
||||||
|
sub = df[df['prior-type'] == prior]
|
||||||
|
|
||||||
|
sns.lineplot(
|
||||||
|
data=sub,
|
||||||
|
x='concentration',
|
||||||
|
y='amplitude',
|
||||||
|
hue='method_name',
|
||||||
|
marker='o',
|
||||||
|
errorbar='se',
|
||||||
|
ax=ax
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set_xscale('log')
|
||||||
|
ax.set_title(f'Prior: {prior}')
|
||||||
|
ax.set_xlabel('Concentration')
|
||||||
|
ax.set_ylabel('Amplitude')
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def coverage_vs_amplitude_plot(df):
|
||||||
|
import seaborn as sns
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
agg = (
|
||||||
|
df
|
||||||
|
.groupby(['prior-type', 'method_name', 'concentration'])
|
||||||
|
.agg(
|
||||||
|
coverage=('coverage', 'mean'),
|
||||||
|
amplitude=('amplitude', 'mean')
|
||||||
|
)
|
||||||
|
.reset_index()
|
||||||
|
)
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True)
|
||||||
|
|
||||||
|
for ax, prior in zip(axes, ['informative', 'wrong']):
|
||||||
|
sub = agg[agg['prior-type'] == prior]
|
||||||
|
|
||||||
|
sns.scatterplot(
|
||||||
|
data=sub,
|
||||||
|
x='amplitude',
|
||||||
|
y='coverage',
|
||||||
|
hue='method_name',
|
||||||
|
style='concentration',
|
||||||
|
s=80,
|
||||||
|
ax=ax
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set_ylim(0, 1.05)
|
||||||
|
ax.set_title(f'Prior: {prior}')
|
||||||
|
ax.set_xlabel('Amplitude')
|
||||||
|
ax.set_ylabel('Coverage')
|
||||||
|
ax.axhline(0.95, linestyle='--', color='gray', alpha=0.7)
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
@ -153,7 +252,7 @@ if __name__ == '__main__':
|
||||||
print(f'selected datasets={selected}')
|
print(f'selected datasets={selected}')
|
||||||
qp.environ['SAMPLE_SIZE'] = multiclass['sample_size']
|
qp.environ['SAMPLE_SIZE'] = multiclass['sample_size']
|
||||||
reports = []
|
reports = []
|
||||||
for data_name in selected[:2]:
|
for data_name in selected:
|
||||||
data = multiclass['fetch_fn'](data_name)
|
data = multiclass['fetch_fn'](data_name)
|
||||||
for method_name, surrogate_quant, hyper_params, bay_constructor in methods():
|
for method_name, surrogate_quant, hyper_params, bay_constructor in methods():
|
||||||
result_path = experiment_path(result_dir, data_name, method_name)
|
result_path = experiment_path(result_dir, data_name, method_name)
|
||||||
|
|
@ -168,6 +267,11 @@ if __name__ == '__main__':
|
||||||
# concat all reports as a dataframe
|
# concat all reports as a dataframe
|
||||||
df = concat_reports(reports)
|
df = concat_reports(reports)
|
||||||
|
|
||||||
plot_results(df)
|
for data_name in selected:
|
||||||
|
print(data_name)
|
||||||
|
df_ = df[df['dataset']==data_name]
|
||||||
|
error_vs_concentration_plot(df_)
|
||||||
|
coverage_vs_concentration_plot(df_)
|
||||||
|
amplitude_vs_concentration_plot(df_)
|
||||||
|
coverage_vs_amplitude_plot(df_)
|
||||||
|
|
||||||
print('ONLY TWO DATASETS')
|
|
||||||
Loading…
Reference in New Issue