reducing strength for antagonic prior

This commit is contained in:
Alejandro Moreo Fernandez 2026-01-13 18:19:09 +01:00
parent 93c33fe237
commit a511f577c9
2 changed files with 43 additions and 17 deletions

View File

@ -4,11 +4,17 @@
- analyze across shift
- add Bayesian EM:
- https://github.com/ChangkunYe/MAPLS/blob/main/label_shift/mapls.py
- take this opportunity to add RLLS: https://github.com/ChangkunYe/MAPLS/blob/main/label_shift/rlls.py
- take this opportunity to add RLLS:
https://github.com/Angie-Liu/labelshift
https://github.com/ChangkunYe/MAPLS/blob/main/label_shift/rlls.py
- add CIFAR10 and MNIST? Maybe consider also previously tested types of shift (tweak-one-out, etc.)? from RLLS paper
- https://github.com/Angie-Liu/labelshift/tree/master
- https://github.com/Angie-Liu/labelshift/blob/master/cifar10_for_labelshift.py
- Note: MNIST is downloadable from https://archive.ics.uci.edu/dataset/683/mnist+database+of+handwritten+digits
- Seem to be some pretrained models in:
https://github.com/geifmany/cifar-vgg
https://github.com/EN10/KerasMNIST
https://github.com/tohinz/SVHN-Classifier
- consider prior knowledge in experiments:
- One scenario in which our prior is uninformative (i.e., uniform)
- One scenario in which our prior is wrong (e.g., alpha-prior = (3,2,1), protocol-prior=(1,1,5))

View File

@ -39,6 +39,9 @@ def methods():
yield f'BaKDE-Ait-T2', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0,
engine='numpyro', temperature=2.,
prior='uniform', **hyper)
yield f'BaKDE-Ait-T1', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0,
engine='numpyro', temperature=1.,
prior='uniform', **hyper)
def run_test(test, alpha_test, alpha_train, concentration, prior_type, bay_quant, train_prev, dataset_name, method_name, results):
@ -104,7 +107,7 @@ def experiment(dataset: Dataset,
run_test(test, alpha_test_informative, alpha_train, concentration, prior_type, bay_quant, train_prev, dataset_name, method_name, results)
# informative prior
alpha_test_wrong = antagonistic_prevalence(train_prev, strength=1) * concentration
alpha_test_wrong = antagonistic_prevalence(train_prev, strength=0.5) * concentration
prior_type = 'wrong'
run_test(test, alpha_test_wrong, alpha_train, concentration, prior_type, bay_quant, train_prev, dataset_name, method_name, results)
@ -120,7 +123,7 @@ def concat_reports(reports):
return df
def error_vs_concentration_plot(df, err='ae'):
def error_vs_concentration_plot(df, err='ae', save_path=None):
import seaborn as sns
import matplotlib.pyplot as plt
@ -147,10 +150,14 @@ def error_vs_concentration_plot(df, err='ae'):
ax.set_ylabel('M'+err.upper())
plt.tight_layout()
plt.show()
if save_path is None:
plt.show()
else:
os.makedirs(Path(save_path).parent, exist_ok=True)
plt.savefig(save_path)
def coverage_vs_concentration_plot(df):
def coverage_vs_concentration_plot(df, save_path=None):
import seaborn as sns
import matplotlib.pyplot as plt
@ -176,10 +183,14 @@ def coverage_vs_concentration_plot(df):
ax.set_ylabel('Coverage')
plt.tight_layout()
plt.show()
if save_path is None:
plt.show()
else:
os.makedirs(Path(save_path).parent, exist_ok=True)
plt.savefig(save_path)
def amplitude_vs_concentration_plot(df):
def amplitude_vs_concentration_plot(df, save_path=None):
import seaborn as sns
import matplotlib.pyplot as plt
@ -204,10 +215,14 @@ def amplitude_vs_concentration_plot(df):
ax.set_ylabel('Amplitude')
plt.tight_layout()
plt.show()
if save_path is None:
plt.show()
else:
os.makedirs(Path(save_path).parent, exist_ok=True)
plt.savefig(save_path)
def coverage_vs_amplitude_plot(df):
def coverage_vs_amplitude_plot(df, save_path=None):
import seaborn as sns
import matplotlib.pyplot as plt
@ -243,7 +258,11 @@ def coverage_vs_amplitude_plot(df):
ax.axhline(0.95, linestyle='--', color='gray', alpha=0.7)
plt.tight_layout()
plt.show()
if save_path is None:
plt.show()
else:
os.makedirs(Path(save_path).parent, exist_ok=True)
plt.savefig(save_path)
if __name__ == '__main__':
@ -267,11 +286,12 @@ if __name__ == '__main__':
# concat all reports as a dataframe
df = concat_reports(reports)
for data_name in selected:
print(data_name)
df_ = df[df['dataset']==data_name]
error_vs_concentration_plot(df_)
coverage_vs_concentration_plot(df_)
amplitude_vs_concentration_plot(df_)
coverage_vs_amplitude_plot(df_)
# for data_name in selected:
# print(data_name)
# df_ = df[df['dataset']==data_name]
df_ = df
error_vs_concentration_plot(df_, save_path='./plots/prior_effect/error_vs_concentration.pdf')
coverage_vs_concentration_plot(df_, save_path='./plots/prior_effect/coverage_vs_concentration.pdf')
amplitude_vs_concentration_plot(df_, save_path='./plots/prior_effect/amplitude_vs_concentration.pdf')
coverage_vs_amplitude_plot(df_, save_path='./plots/prior_effect/coverage_vs_amplitude.pdf')