bugfix
This commit is contained in:
parent
934b09fa66
commit
300b8e6423
|
|
@ -22,13 +22,13 @@ def fetch_UCI_binary(data_name):
|
|||
# global configurations
|
||||
|
||||
binary = {
|
||||
'datasets': qp.datasets.UCI_BINARY_DATASETS,
|
||||
'datasets': qp.datasets.UCI_BINARY_DATASETS.copy(),
|
||||
'fetch_fn': fetch_UCI_binary,
|
||||
'sample_size': 500
|
||||
}
|
||||
|
||||
multiclass = {
|
||||
'datasets': qp.datasets.UCI_MULTICLASS_DATASETS,
|
||||
'datasets': qp.datasets.UCI_MULTICLASS_DATASETS.copy(),
|
||||
'fetch_fn': fetch_UCI_multiclass,
|
||||
'sample_size': 1000
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from sklearn.linear_model import LogisticRegression as LR
|
|||
from copy import deepcopy as cp
|
||||
from tqdm import tqdm
|
||||
from full_experiments import model_selection
|
||||
from itertools import chain
|
||||
|
||||
|
||||
def select_imbalanced_datasets(top_m=5):
|
||||
|
|
@ -107,13 +108,52 @@ def experiment(dataset: Dataset,
|
|||
return results
|
||||
|
||||
|
||||
def concat_reports(reports):
|
||||
final_report = {
|
||||
k: list(chain.from_iterable(report[k] for report in reports))
|
||||
for k in reports[0]
|
||||
}
|
||||
df = pd.DataFrame(final_report)
|
||||
return df
|
||||
|
||||
|
||||
def plot_results(df):
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
sns.set_theme(style="whitegrid", context="paper")
|
||||
|
||||
fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True)
|
||||
|
||||
for ax, prior in zip(axes, ['informative', 'wrong']):
|
||||
sub = df[df['prior-type'] == prior]
|
||||
|
||||
sns.lineplot(
|
||||
data=sub,
|
||||
x='concentration',
|
||||
y='ae',
|
||||
hue='method_name',
|
||||
marker='o',
|
||||
errorbar='se', # o 'sd'
|
||||
ax=ax
|
||||
)
|
||||
|
||||
ax.set_xscale('log')
|
||||
ax.set_title(f'Prior: {prior}')
|
||||
ax.set_xlabel('Concentration')
|
||||
ax.set_ylabel('MAE')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
result_dir = Path('./results/prior_effect')
|
||||
selected = select_imbalanced_datasets()
|
||||
print(f'selected datasets={selected}')
|
||||
qp.environ['SAMPLE_SIZE'] = multiclass['sample_size']
|
||||
reports = []
|
||||
for data_name in selected:
|
||||
for data_name in selected[:2]:
|
||||
data = multiclass['fetch_fn'](data_name)
|
||||
for method_name, surrogate_quant, hyper_params, bay_constructor in methods():
|
||||
result_path = experiment_path(result_dir, data_name, method_name)
|
||||
|
|
@ -125,6 +165,9 @@ if __name__ == '__main__':
|
|||
)
|
||||
reports.append(report)
|
||||
|
||||
# concat all reports as a dataframe
|
||||
df = concat_reports(reports)
|
||||
|
||||
plot_results(df)
|
||||
|
||||
# df = pd.DataFrame(results)
|
||||
print('ONLY TWO DATASETS')
|
||||
Loading…
Reference in New Issue