130 lines
3.4 KiB
Python
130 lines
3.4 KiB
Python
import os.path
|
|
import pickle
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
import pandas as pd
|
|
import quapy as qp
|
|
from BayesianKDEy._bayeisan_kdey import BayesianKDEy
|
|
from BayesianKDEy.full_experiments import experiment_path
|
|
from quapy.protocol import UPP
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
from joblib import Parallel, delayed
|
|
from itertools import product
|
|
|
|
# Is there a correspondence between smaller bandwidths and overconfident likelihoods? If so, smaller bandwidths
# might be correlated with a higher temperature after calibration; this script aims at analyzing this trend.
|
|
|
|
# dataset names that the experiment loop later in this script iterates over
datasets = qp.datasets.UCI_MULTICLASS_DATASETS
|
|
|
|
def show(results, values):
    """Plot the mean of one metric as a bandwidth-by-temperature heatmap.

    The data in ``results`` is averaged per (bandwidth, temperature) pair,
    the column named by ``values`` is arranged as a grid (bandwidths on the
    rows, temperatures on the columns) and the resulting image is written to
    ``./plotcorr/<values>.png``.

    :param results: data accepted by ``pd.DataFrame``; it must provide the
        columns ``bandwidth`` and ``temperature`` plus the column named by
        ``values``
    :param values: name of the numeric column to display; also used for the
        colorbar label, the title and the output file name
    """
    # local import: matplotlib is loaded only when this function runs
    import matplotlib.pyplot as plt

    frame = pd.DataFrame(results)
    grouped = frame.groupby(['bandwidth', 'temperature'], as_index=False)
    averages = grouped.mean(numeric_only=True)
    grid = averages.pivot(index='bandwidth', columns='temperature', values=values)

    plt.imshow(grid, origin='lower', aspect='auto')
    plt.colorbar(label=values)

    # place one tick per grid cell and label it with the actual grid value
    column_positions = range(len(grid.columns))
    row_positions = range(len(grid.index))
    plt.xticks(column_positions, grid.columns)
    plt.yticks(row_positions, grid.index)

    plt.xlabel('Temperature')
    plt.ylabel('Bandwidth')
    plt.title(f'{values} heatmap')

    plt.savefig(f'./plotcorr/{values}.png')

    # clear the current axes and figure so that the next call starts clean
    plt.cla()
    plt.clf()
|
|
|
|
def run_experiment(
    bandwidth,
    temperature,
    train,
    test,
    dataset,
):
    """Fit one BayesianKDEy configuration and evaluate it on test samples.

    A ``BayesianKDEy`` (numpyro engine) is fitted on ``train`` with the given
    ``bandwidth`` and ``temperature``; it is then applied to every sample
    generated from ``test`` by a ``UPP`` protocol (``repeats=20``,
    ``random_state=0``).

    :param bandwidth: bandwidth passed to ``BayesianKDEy``
    :param temperature: temperature passed to ``BayesianKDEy``
    :param train: training collection; its ``Xy`` attribute is unpacked into ``fit``
    :param test: collection from which the ``UPP`` protocol draws the samples
    :param dataset: dataset identifier, copied verbatim into every result row
    :return: a list with one dict per generated sample, with keys
        ``bandwidth``, ``temperature``, ``dataset``, ``sample`` (the sample
        index), ``mae`` (``qp.error.mae`` between the sample prevalence and
        the point estimate), ``cov`` (``region.coverage`` of the sample
        prevalence) and ``amp`` (``region.montecarlo_proportion`` with
        50,000 trials)
    """
    # set inside the function so that it also applies when run in a worker
    qp.environ['SAMPLE_SIZE'] = 1000

    quantifier = BayesianKDEy(
        engine='numpyro',
        bandwidth=bandwidth,
        temperature=temperature,
    )
    quantifier.fit(*train.Xy)

    protocol = UPP(test, repeats=20, random_state=0)
    progress = tqdm(
        protocol(),
        desc=f"bw={bandwidth}, T={temperature}",
        total=protocol.total(),
        leave=False,
    )

    def evaluate(index, sample, true_prev):
        # one result row: point-estimate error plus the two region metrics
        estimate, conf_region = quantifier.predict_conf(sample)
        return {
            "bandwidth": bandwidth,
            "temperature": temperature,
            "dataset": dataset,
            "sample": index,
            "mae": qp.error.mae(true_prev, estimate),
            "cov": conf_region.coverage(true_prev),
            "amp": conf_region.montecarlo_proportion(n_trials=50_000),
        }

    return [
        evaluate(index, sample, true_prev)
        for index, (sample, true_prev) in enumerate(progress)
    ]
|
|
|
|
|
|
# hyperparameter grid: every bandwidth is combined with every temperature
bandwidths = [0.001, 0.005, 0.01, 0.05, 0.1]
temperatures = [0.5, 0.75, 1., 2., 5.]

# per-dataset results are cached as pickles in this directory, so a rerun
# only computes the datasets that have no cached file yet
res_dir = './plotcorr/results'
os.makedirs(res_dir, exist_ok=True)

all_rows = []
for i, dataset in enumerate(datasets):
    # 'letter' and 'isolet' are excluded from this analysis
    if dataset in ['letter', 'isolet']:
        continue

    res_path = f'{res_dir}/{dataset}.pkl'
    if os.path.exists(res_path):
        print(f'loading results from pickle {res_path}')
        # NOTE: unpickling is only safe for files written by this script;
        # never point res_dir at untrusted content
        with open(res_path, 'rb') as cache_file:
            results_data_rows = pickle.load(cache_file)
    else:
        print(f'{dataset=} [complete={i}/{len(datasets)}]')
        data = qp.datasets.fetch_UCIMulticlassDataset(dataset)
        train, test = data.train_test
        jobs = list(product(bandwidths, temperatures))
        # one job per (bandwidth, temperature) pair; each job returns a list of rows
        results_data_rows = Parallel(n_jobs=-1, backend="loky")(
            delayed(run_experiment)(bw, T, train, test, dataset) for bw, T in jobs
        )
        # the context manager guarantees the cache file is flushed and closed
        with open(res_path, 'wb') as cache_file:
            pickle.dump(results_data_rows, cache_file, pickle.HIGHEST_PROTOCOL)

    all_rows.extend(results_data_rows)

# flatten the per-job row lists into a column-oriented dict (column name -> values)
results = defaultdict(list)
for rows in all_rows:
    for row in rows:
        for k, v in row.items():
            results[k].append(v)

show(results, 'mae')
show(results, 'cov')
show(results, 'amp')
|
|
|
|
|
|
|
|
|
|
|
|
|