# QuaPy/BayesianKDEy/temperature_calibration.py
#
# 115 lines
# 3.7 KiB
# Python

from build.lib.quapy.data import LabelledCollection
from quapy.method.confidence import WithConfidenceABC
from quapy.protocol import UPP
import numpy as np
from tqdm import tqdm
import quapy as qp
from joblib import Parallel, delayed
import copy
def temp_calibration(method:WithConfidenceABC,
                     train:LabelledCollection,
                     val:LabelledCollection,
                     temp_grid=[.5, 1., 1.5, 2., 5., 10., 100.],
                     num_samples=100,
                     nominal_coverage=0.95,
                     amplitude_threshold='auto',
                     random_state=0,
                     n_jobs=1,
                     verbose=True):
    """Select the temperature whose confidence regions best match a nominal coverage.

    The quantifier is fitted on `train`, and then every temperature in
    `temp_grid` is evaluated on samples drawn from `val` through a `UPP`
    protocol. For each temperature two statistics are collected:

    * coverage: the fraction of samples whose true prevalence lies inside the
      confidence region returned by `predict_conf`;
    * amplitude: the mean of `conf_region.montecarlo_proportion(n_trials=50_000)`
      over the samples (presumably the share of the simplex occupied by the
      region -- confirm against the confidence-region implementation).

    Temperatures whose mean amplitude is not strictly below
    `amplitude_threshold` are discarded. Among the remaining ones, the
    temperature whose coverage is closest to `nominal_coverage` is returned;
    ties are resolved in favour of the smallest temperature, since candidates
    are visited in ascending order.

    :param method: quantifier exposing `fit`, `predict_conf` and a settable
        `temperature` attribute. NOTE: it is fitted in place on `train`; only
        the per-temperature evaluation works on deep copies.
    :param train: collection used to fit `method` (via `train.Xy`).
    :param val: collection from which the evaluation samples are generated.
    :param temp_grid: candidate temperatures. The default list is never
        mutated; a sorted copy is used internally.
    :param num_samples: number of samples requested from the protocol
        (passed as `repeats`).
    :param nominal_coverage: target coverage to approximate.
    :param amplitude_threshold: either a float < 1.0 or 'auto', in which case
        it is set to `0.1 / ln(n_classes + 1)` with `n_classes` taken from
        `train`.
    :param random_state: seed handed to the sampling protocol.
    :param n_jobs: number of joblib workers used to evaluate the grid.
    :param verbose: whether to show the progress bar over the grid. The
        threshold warning and the final report are printed regardless.
    :return: the chosen temperature, or 1.0 if no candidate satisfies the
        amplitude constraint.
    :raises AssertionError: if `amplitude_threshold` is neither 'auto' nor a
        float smaller than 1.0.
    """
    # NOTE: kept as an assertion so that the exception type seen by callers is
    # unchanged; be aware that it is skipped when Python runs with -O.
    assert (amplitude_threshold == 'auto' or (isinstance(amplitude_threshold, float)) and amplitude_threshold < 1.), \
        f'wrong value for {amplitude_threshold=}, it must either be "auto" or a float < 1.0.'

    if amplitude_threshold=='auto':
        n_classes = train.n_classes
        amplitude_threshold = .1/np.log(n_classes+1)

    # np.float64 is a subclass of float, so the 'auto' value is checked too
    if isinstance(amplitude_threshold, float) and amplitude_threshold > 0.1:
        print(f'warning: the {amplitude_threshold=} is too large; this may lead to uninformative regions')

    method.fit(*train.Xy)

    label_shift_prot = UPP(val, repeats=num_samples, random_state=random_state)

    def evaluate_temperature(temp):
        """Return (temp, mean coverage, mean amplitude) for one temperature."""
        # work on a private copy so that concurrent evaluations never share
        # (or overwrite) the `temperature` attribute of the fitted method
        local_method = copy.deepcopy(method)
        local_method.temperature = temp

        coverage = 0
        amplitudes = []
        for sample, prev in label_shift_prot():
            # the point estimate is not needed for calibration
            _, conf_region = local_method.predict_conf(sample)
            if prev in conf_region:
                coverage += 1
            amplitudes.append(conf_region.montecarlo_proportion(n_trials=50_000))

        mean_coverage = coverage / label_shift_prot.total()
        mean_amplitude = np.mean(amplitudes)
        return temp, mean_coverage, mean_amplitude

    # ascending order makes the tie-breaking of `min` below deterministic
    temp_grid = sorted(temp_grid)

    # the progress bar tracks task dispatching, not task completion
    raw_results = Parallel(n_jobs=n_jobs, backend="loky")(
        delayed(evaluate_temperature)(temp)
        for temp in tqdm(temp_grid, disable=not verbose)
    )

    # keep only the temperatures yielding sufficiently narrow regions
    results = [
        (temp, cov, amp)
        for temp, cov, amp in raw_results
        if amp < amplitude_threshold
    ]

    chosen_temperature = 1.
    if results:
        chosen_temperature = min(results, key=lambda x: abs(x[1]-nominal_coverage))[0]

    print(f'chosen_temperature={chosen_temperature:.2f}')

    return chosen_temperature