reducing strength for antagonic prior
This commit is contained in:
parent
93c33fe237
commit
a511f577c9
|
|
@ -4,11 +4,17 @@
|
|||
- analyze across shift
|
||||
- add Bayesian EM:
|
||||
- https://github.com/ChangkunYe/MAPLS/blob/main/label_shift/mapls.py
|
||||
- take this opportunity to add RLLS: https://github.com/ChangkunYe/MAPLS/blob/main/label_shift/rlls.py
|
||||
- take this opportunity to add RLLS:
|
||||
https://github.com/Angie-Liu/labelshift
|
||||
https://github.com/ChangkunYe/MAPLS/blob/main/label_shift/rlls.py
|
||||
- add CIFAR10 and MNIST? Maybe consider also previously tested types of shift (tweak-one-out, etc.)? from RLLS paper
|
||||
- https://github.com/Angie-Liu/labelshift/tree/master
|
||||
- https://github.com/Angie-Liu/labelshift/blob/master/cifar10_for_labelshift.py
|
||||
- Note: MNIST is downloadable from https://archive.ics.uci.edu/dataset/683/mnist+database+of+handwritten+digits
|
||||
- Seem to be some pretrained models in:
|
||||
https://github.com/geifmany/cifar-vgg
|
||||
https://github.com/EN10/KerasMNIST
|
||||
https://github.com/tohinz/SVHN-Classifier
|
||||
- consider prior knowledge in experiments:
|
||||
- One scenario in which our prior is uninformative (i.e., uniform)
|
||||
- One scenario in which our prior is wrong (e.g., alpha-prior = (3,2,1), protocol-prior=(1,1,5))
|
||||
|
|
|
|||
|
|
@ -39,6 +39,9 @@ def methods():
|
|||
yield f'BaKDE-Ait-T2', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0,
|
||||
engine='numpyro', temperature=2.,
|
||||
prior='uniform', **hyper)
|
||||
yield f'BaKDE-Ait-T1', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0,
|
||||
engine='numpyro', temperature=1.,
|
||||
prior='uniform', **hyper)
|
||||
|
||||
|
||||
def run_test(test, alpha_test, alpha_train, concentration, prior_type, bay_quant, train_prev, dataset_name, method_name, results):
|
||||
|
|
@ -104,7 +107,7 @@ def experiment(dataset: Dataset,
|
|||
run_test(test, alpha_test_informative, alpha_train, concentration, prior_type, bay_quant, train_prev, dataset_name, method_name, results)
|
||||
|
||||
# informative prior
|
||||
alpha_test_wrong = antagonistic_prevalence(train_prev, strength=1) * concentration
|
||||
alpha_test_wrong = antagonistic_prevalence(train_prev, strength=0.5) * concentration
|
||||
prior_type = 'wrong'
|
||||
run_test(test, alpha_test_wrong, alpha_train, concentration, prior_type, bay_quant, train_prev, dataset_name, method_name, results)
|
||||
|
||||
|
|
@ -120,7 +123,7 @@ def concat_reports(reports):
|
|||
return df
|
||||
|
||||
|
||||
def error_vs_concentration_plot(df, err='ae'):
|
||||
def error_vs_concentration_plot(df, err='ae', save_path=None):
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
|
@ -147,10 +150,14 @@ def error_vs_concentration_plot(df, err='ae'):
|
|||
ax.set_ylabel('M'+err.upper())
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
if save_path is None:
|
||||
plt.show()
|
||||
else:
|
||||
os.makedirs(Path(save_path).parent, exist_ok=True)
|
||||
plt.savefig(save_path)
|
||||
|
||||
|
||||
def coverage_vs_concentration_plot(df):
|
||||
def coverage_vs_concentration_plot(df, save_path=None):
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
|
@ -176,10 +183,14 @@ def coverage_vs_concentration_plot(df):
|
|||
ax.set_ylabel('Coverage')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
if save_path is None:
|
||||
plt.show()
|
||||
else:
|
||||
os.makedirs(Path(save_path).parent, exist_ok=True)
|
||||
plt.savefig(save_path)
|
||||
|
||||
|
||||
def amplitude_vs_concentration_plot(df):
|
||||
def amplitude_vs_concentration_plot(df, save_path=None):
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
|
@ -204,10 +215,14 @@ def amplitude_vs_concentration_plot(df):
|
|||
ax.set_ylabel('Amplitude')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
if save_path is None:
|
||||
plt.show()
|
||||
else:
|
||||
os.makedirs(Path(save_path).parent, exist_ok=True)
|
||||
plt.savefig(save_path)
|
||||
|
||||
|
||||
def coverage_vs_amplitude_plot(df):
|
||||
def coverage_vs_amplitude_plot(df, save_path=None):
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
|
@ -243,7 +258,11 @@ def coverage_vs_amplitude_plot(df):
|
|||
ax.axhline(0.95, linestyle='--', color='gray', alpha=0.7)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
if save_path is None:
|
||||
plt.show()
|
||||
else:
|
||||
os.makedirs(Path(save_path).parent, exist_ok=True)
|
||||
plt.savefig(save_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
@ -267,11 +286,12 @@ if __name__ == '__main__':
|
|||
# concat all reports as a dataframe
|
||||
df = concat_reports(reports)
|
||||
|
||||
for data_name in selected:
|
||||
print(data_name)
|
||||
df_ = df[df['dataset']==data_name]
|
||||
error_vs_concentration_plot(df_)
|
||||
coverage_vs_concentration_plot(df_)
|
||||
amplitude_vs_concentration_plot(df_)
|
||||
coverage_vs_amplitude_plot(df_)
|
||||
# for data_name in selected:
|
||||
# print(data_name)
|
||||
# df_ = df[df['dataset']==data_name]
|
||||
df_ = df
|
||||
error_vs_concentration_plot(df_, save_path='./plots/prior_effect/error_vs_concentration.pdf')
|
||||
coverage_vs_concentration_plot(df_, save_path='./plots/prior_effect/coverage_vs_concentration.pdf')
|
||||
amplitude_vs_concentration_plot(df_, save_path='./plots/prior_effect/amplitude_vs_concentration.pdf')
|
||||
coverage_vs_amplitude_plot(df_, save_path='./plots/prior_effect/coverage_vs_amplitude.pdf')
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue