reducing strength for antagonic prior
This commit is contained in:
parent
93c33fe237
commit
a511f577c9
|
|
@ -4,11 +4,17 @@
|
||||||
- analyze across shift
|
- analyze across shift
|
||||||
- add Bayesian EM:
|
- add Bayesian EM:
|
||||||
- https://github.com/ChangkunYe/MAPLS/blob/main/label_shift/mapls.py
|
- https://github.com/ChangkunYe/MAPLS/blob/main/label_shift/mapls.py
|
||||||
- take this opportunity to add RLLS: https://github.com/ChangkunYe/MAPLS/blob/main/label_shift/rlls.py
|
- take this opportunity to add RLLS:
|
||||||
|
https://github.com/Angie-Liu/labelshift
|
||||||
|
https://github.com/ChangkunYe/MAPLS/blob/main/label_shift/rlls.py
|
||||||
- add CIFAR10 and MNIST? Maybe consider also previously tested types of shift (tweak-one-out, etc.)? from RLLS paper
|
- add CIFAR10 and MNIST? Maybe consider also previously tested types of shift (tweak-one-out, etc.)? from RLLS paper
|
||||||
- https://github.com/Angie-Liu/labelshift/tree/master
|
- https://github.com/Angie-Liu/labelshift/tree/master
|
||||||
- https://github.com/Angie-Liu/labelshift/blob/master/cifar10_for_labelshift.py
|
- https://github.com/Angie-Liu/labelshift/blob/master/cifar10_for_labelshift.py
|
||||||
- Note: MNIST is downloadable from https://archive.ics.uci.edu/dataset/683/mnist+database+of+handwritten+digits
|
- Note: MNIST is downloadable from https://archive.ics.uci.edu/dataset/683/mnist+database+of+handwritten+digits
|
||||||
|
- Seem to be some pretrained models in:
|
||||||
|
https://github.com/geifmany/cifar-vgg
|
||||||
|
https://github.com/EN10/KerasMNIST
|
||||||
|
https://github.com/tohinz/SVHN-Classifier
|
||||||
- consider prior knowledge in experiments:
|
- consider prior knowledge in experiments:
|
||||||
- One scenario in which our prior is uninformative (i.e., uniform)
|
- One scenario in which our prior is uninformative (i.e., uniform)
|
||||||
- One scenario in which our prior is wrong (e.g., alpha-prior = (3,2,1), protocol-prior=(1,1,5))
|
- One scenario in which our prior is wrong (e.g., alpha-prior = (3,2,1), protocol-prior=(1,1,5))
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,9 @@ def methods():
|
||||||
yield f'BaKDE-Ait-T2', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0,
|
yield f'BaKDE-Ait-T2', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0,
|
||||||
engine='numpyro', temperature=2.,
|
engine='numpyro', temperature=2.,
|
||||||
prior='uniform', **hyper)
|
prior='uniform', **hyper)
|
||||||
|
yield f'BaKDE-Ait-T1', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0,
|
||||||
|
engine='numpyro', temperature=1.,
|
||||||
|
prior='uniform', **hyper)
|
||||||
|
|
||||||
|
|
||||||
def run_test(test, alpha_test, alpha_train, concentration, prior_type, bay_quant, train_prev, dataset_name, method_name, results):
|
def run_test(test, alpha_test, alpha_train, concentration, prior_type, bay_quant, train_prev, dataset_name, method_name, results):
|
||||||
|
|
@ -104,7 +107,7 @@ def experiment(dataset: Dataset,
|
||||||
run_test(test, alpha_test_informative, alpha_train, concentration, prior_type, bay_quant, train_prev, dataset_name, method_name, results)
|
run_test(test, alpha_test_informative, alpha_train, concentration, prior_type, bay_quant, train_prev, dataset_name, method_name, results)
|
||||||
|
|
||||||
# informative prior
|
# informative prior
|
||||||
alpha_test_wrong = antagonistic_prevalence(train_prev, strength=1) * concentration
|
alpha_test_wrong = antagonistic_prevalence(train_prev, strength=0.5) * concentration
|
||||||
prior_type = 'wrong'
|
prior_type = 'wrong'
|
||||||
run_test(test, alpha_test_wrong, alpha_train, concentration, prior_type, bay_quant, train_prev, dataset_name, method_name, results)
|
run_test(test, alpha_test_wrong, alpha_train, concentration, prior_type, bay_quant, train_prev, dataset_name, method_name, results)
|
||||||
|
|
||||||
|
|
@ -120,7 +123,7 @@ def concat_reports(reports):
|
||||||
return df
|
return df
|
||||||
|
|
||||||
|
|
||||||
def error_vs_concentration_plot(df, err='ae'):
|
def error_vs_concentration_plot(df, err='ae', save_path=None):
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
@ -147,10 +150,14 @@ def error_vs_concentration_plot(df, err='ae'):
|
||||||
ax.set_ylabel('M'+err.upper())
|
ax.set_ylabel('M'+err.upper())
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.show()
|
if save_path is None:
|
||||||
|
plt.show()
|
||||||
|
else:
|
||||||
|
os.makedirs(Path(save_path).parent, exist_ok=True)
|
||||||
|
plt.savefig(save_path)
|
||||||
|
|
||||||
|
|
||||||
def coverage_vs_concentration_plot(df):
|
def coverage_vs_concentration_plot(df, save_path=None):
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
@ -176,10 +183,14 @@ def coverage_vs_concentration_plot(df):
|
||||||
ax.set_ylabel('Coverage')
|
ax.set_ylabel('Coverage')
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.show()
|
if save_path is None:
|
||||||
|
plt.show()
|
||||||
|
else:
|
||||||
|
os.makedirs(Path(save_path).parent, exist_ok=True)
|
||||||
|
plt.savefig(save_path)
|
||||||
|
|
||||||
|
|
||||||
def amplitude_vs_concentration_plot(df):
|
def amplitude_vs_concentration_plot(df, save_path=None):
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
@ -204,10 +215,14 @@ def amplitude_vs_concentration_plot(df):
|
||||||
ax.set_ylabel('Amplitude')
|
ax.set_ylabel('Amplitude')
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.show()
|
if save_path is None:
|
||||||
|
plt.show()
|
||||||
|
else:
|
||||||
|
os.makedirs(Path(save_path).parent, exist_ok=True)
|
||||||
|
plt.savefig(save_path)
|
||||||
|
|
||||||
|
|
||||||
def coverage_vs_amplitude_plot(df):
|
def coverage_vs_amplitude_plot(df, save_path=None):
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
@ -243,7 +258,11 @@ def coverage_vs_amplitude_plot(df):
|
||||||
ax.axhline(0.95, linestyle='--', color='gray', alpha=0.7)
|
ax.axhline(0.95, linestyle='--', color='gray', alpha=0.7)
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.show()
|
if save_path is None:
|
||||||
|
plt.show()
|
||||||
|
else:
|
||||||
|
os.makedirs(Path(save_path).parent, exist_ok=True)
|
||||||
|
plt.savefig(save_path)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
@ -267,11 +286,12 @@ if __name__ == '__main__':
|
||||||
# concat all reports as a dataframe
|
# concat all reports as a dataframe
|
||||||
df = concat_reports(reports)
|
df = concat_reports(reports)
|
||||||
|
|
||||||
for data_name in selected:
|
# for data_name in selected:
|
||||||
print(data_name)
|
# print(data_name)
|
||||||
df_ = df[df['dataset']==data_name]
|
# df_ = df[df['dataset']==data_name]
|
||||||
error_vs_concentration_plot(df_)
|
df_ = df
|
||||||
coverage_vs_concentration_plot(df_)
|
error_vs_concentration_plot(df_, save_path='./plots/prior_effect/error_vs_concentration.pdf')
|
||||||
amplitude_vs_concentration_plot(df_)
|
coverage_vs_concentration_plot(df_, save_path='./plots/prior_effect/coverage_vs_concentration.pdf')
|
||||||
coverage_vs_amplitude_plot(df_)
|
amplitude_vs_concentration_plot(df_, save_path='./plots/prior_effect/amplitude_vs_concentration.pdf')
|
||||||
|
coverage_vs_amplitude_plot(df_, save_path='./plots/prior_effect/coverage_vs_amplitude.pdf')
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue