This commit is contained in:
Alejandro Moreo Fernandez 2026-01-13 15:01:02 +01:00
parent 934b09fa66
commit 300b8e6423
2 changed files with 47 additions and 4 deletions

View File

@ -22,13 +22,13 @@ def fetch_UCI_binary(data_name):
# global configurations
# Settings for the binary UCI benchmark: the dataset names to iterate over,
# the loader used to fetch each of them, and the sample size to evaluate with.
binary = {
    # A copy is stored (presumably of a list of dataset names -- confirm) so
    # that edits made through this config do not alter the constant owned by
    # quapy. The key is listed only once: a repeated key in a dict display is
    # silently overridden by its last occurrence.
    'datasets': qp.datasets.UCI_BINARY_DATASETS.copy(),
    'fetch_fn': fetch_UCI_binary,
    'sample_size': 500
}
# Settings for the multiclass UCI benchmark: the dataset names to iterate
# over, the loader used to fetch each of them, and the sample size to
# evaluate with.
multiclass = {
    # A copy is stored (presumably of a list of dataset names -- confirm) so
    # that edits made through this config do not alter the constant owned by
    # quapy. The key is listed only once: a repeated key in a dict display is
    # silently overridden by its last occurrence.
    'datasets': qp.datasets.UCI_MULTICLASS_DATASETS.copy(),
    'fetch_fn': fetch_UCI_multiclass,
    'sample_size': 1000
}

View File

@ -14,6 +14,7 @@ from sklearn.linear_model import LogisticRegression as LR
from copy import deepcopy as cp
from tqdm import tqdm
from full_experiments import model_selection
from itertools import chain
def select_imbalanced_datasets(top_m=5):
@ -107,13 +108,52 @@ def experiment(dataset: Dataset,
return results
def concat_reports(reports):
    """Merge several per-experiment reports into a single DataFrame.

    Each report is expected to be a mapping from column name to an iterable
    of values. For every key of the first report, the values found under that
    key are concatenated across all reports, in the order the reports are
    given.

    :param reports: iterable of report mappings; may be empty
    :return: a ``pd.DataFrame`` with one column per key of the first report,
        or an empty DataFrame when no report was given
    :raises KeyError: if a later report lacks a key present in the first one
    """
    # Materialise once so that generators are accepted and can be traversed
    # once per column.
    reports = list(reports)
    if not reports:
        # Nothing was collected (e.g. no experiment produced a report);
        # indexing reports[0] would raise IndexError.
        return pd.DataFrame()

    # The first report defines which columns the result has.
    merged = {
        key: list(chain.from_iterable(report[key] for report in reports))
        for key in reports[0]
    }
    return pd.DataFrame(merged)
def plot_results(df):
    """Show the absolute error against the prior concentration, per prior type.

    Draws two side-by-side panels sharing the y axis, one for the rows whose
    'prior-type' is 'informative' and one for those where it is 'wrong'. Each
    panel plots 'ae' over 'concentration' (log-scaled x axis) with one line
    per 'method_name', and the figure is displayed with ``plt.show()``.

    :param df: DataFrame with the columns 'prior-type', 'concentration', 'ae'
        and 'method_name'
    :return: None
    """
    # Plotting libraries are imported lazily, only when a plot is requested.
    import seaborn as sns
    import matplotlib.pyplot as plt

    sns.set_theme(style="whitegrid", context="paper")
    _, panels = plt.subplots(1, 2, figsize=(10, 4), sharey=True)

    prior_types = ['informative', 'wrong']
    for panel, prior_type in zip(panels, prior_types):
        is_selected = df['prior-type'] == prior_type
        sns.lineplot(
            data=df[is_selected],
            x='concentration',
            y='ae',
            hue='method_name',
            marker='o',
            errorbar='se',  # or 'sd'
            ax=panel,
        )
        panel.set_xscale('log')
        panel.set_title(f'Prior: {prior_type}')
        panel.set_xlabel('Concentration')
        # seaborn aggregates repeated x values (mean by default), so the line
        # shows the mean of the absolute errors.
        panel.set_ylabel('MAE')

    plt.tight_layout()
    plt.show()
if __name__ == '__main__':
result_dir = Path('./results/prior_effect')
selected = select_imbalanced_datasets()
print(f'selected datasets={selected}')
qp.environ['SAMPLE_SIZE'] = multiclass['sample_size']
reports = []
for data_name in selected:
for data_name in selected[:2]:
data = multiclass['fetch_fn'](data_name)
for method_name, surrogate_quant, hyper_params, bay_constructor in methods():
result_path = experiment_path(result_dir, data_name, method_name)
@ -125,6 +165,9 @@ if __name__ == '__main__':
)
reports.append(report)
# concat all reports as a dataframe
df = concat_reports(reports)
plot_results(df)
# df = pd.DataFrame(results)
print('ONLY TWO DATASETS')