115 lines
3.7 KiB
Python
115 lines
3.7 KiB
Python
from build.lib.quapy.data import LabelledCollection
|
|
from quapy.method.confidence import WithConfidenceABC
|
|
from quapy.protocol import UPP
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
import quapy as qp
|
|
from joblib import Parallel, delayed
|
|
import copy
|
|
|
|
|
|
|
|
|
|
def temp_calibration(method: WithConfidenceABC,
                     train: LabelledCollection,
                     val: LabelledCollection,
                     temp_grid=(.5, 1., 1.5, 2., 5., 10., 100.),
                     num_samples=100,
                     nominal_coverage=0.95,
                     amplitude_threshold='auto',
                     random_state=0,
                     n_jobs=1,
                     verbose=True):
    """Select the temperature whose confidence regions best match a nominal coverage.

    ``method`` is fitted on ``train`` (in place). Then, for every temperature in
    ``temp_grid``, a deep copy of the fitted method is given that ``temperature``
    and evaluated on the samples generated by a ``UPP`` protocol over ``val``.
    Two statistics are computed per temperature:

    - coverage: fraction of samples whose true prevalence lies inside the
      confidence region returned by ``predict_conf``;
    - amplitude: mean of ``conf_region.montecarlo_proportion(n_trials=50_000)``
      (presumably the fraction of the probability simplex the region covers).

    Among the temperatures whose mean amplitude is strictly below
    ``amplitude_threshold``, the one whose coverage is closest to
    ``nominal_coverage`` is returned. Ties go to the smallest temperature, since
    candidates are scanned in ascending order. If no temperature qualifies, 1.0
    is returned.

    All temperatures are evaluated (in parallel when ``n_jobs != 1``). There is
    no early stop at the first temperature whose amplitude exceeds the threshold.

    :param method: confidence-aware quantifier exposing ``fit``,
        ``predict_conf`` and a settable ``temperature`` attribute. It is fitted
        in place on ``train``; its ``temperature`` attribute is not modified.
    :param train: collection used to fit ``method``.
    :param val: collection from which the UPP samples are drawn.
    :param temp_grid: candidate temperatures (any iterable of numbers).
    :param num_samples: number of UPP samples (``repeats``) per temperature.
    :param nominal_coverage: target coverage level, e.g. 0.95.
    :param amplitude_threshold: ``'auto'`` (resolved to
        ``0.1 / log(n_classes + 1)``) or a float < 1.0. Temperatures whose mean
        amplitude is not below it are discarded.
    :param random_state: passed to ``UPP``. Presumably this makes every
        temperature see the same samples; confirm against quapy's UPP.
    :param n_jobs: number of joblib workers (one task per temperature).
    :param verbose: show a progress bar and print the chosen temperature.
    :return: the selected temperature.
    :raises AssertionError: if ``amplitude_threshold`` is neither ``'auto'``
        nor a float < 1.0.
    """
    import warnings

    assert amplitude_threshold == 'auto' or (
        isinstance(amplitude_threshold, float) and amplitude_threshold < 1.
    ), f'wrong value for {amplitude_threshold=}, it must either be "auto" or a float < 1.0.'

    if amplitude_threshold == 'auto':
        # The admissible region size shrinks as the number of classes grows.
        amplitude_threshold = .1 / np.log(train.n_classes + 1)

    if isinstance(amplitude_threshold, float) and amplitude_threshold > 0.1:
        warnings.warn(
            f'the {amplitude_threshold=} is too large; this may lead to uninformative regions',
            stacklevel=2,
        )

    method.fit(*train.Xy)
    label_shift_prot = UPP(val, repeats=num_samples, random_state=random_state)

    def evaluate_temperature(temp):
        """Return ``(temp, mean_coverage, mean_amplitude)`` for one temperature."""
        # Work on a copy so that neither the caller's `method` nor other tasks
        # (when joblib runs sequentially in-process) observe this temperature.
        local_method = copy.deepcopy(method)
        local_method.temperature = temp

        hits = 0
        amplitudes = []
        for sample, prev in label_shift_prot():
            _, conf_region = local_method.predict_conf(sample)
            if prev in conf_region:
                hits += 1
            amplitudes.append(conf_region.montecarlo_proportion(n_trials=50_000))

        n_evaluated = len(amplitudes)
        if n_evaluated == 0:
            # No samples were drawn. NaN never passes the `amp < threshold`
            # filter, so this temperature is simply not selected.
            return temp, np.nan, np.nan
        return temp, hits / n_evaluated, np.mean(amplitudes)

    temp_grid = sorted(temp_grid)

    # joblib returns results in input order, so raw_results follows temp_grid.
    # The tqdm bar wraps the generator joblib consumes, so it tracks task
    # dispatch rather than task completion.
    raw_results = Parallel(n_jobs=n_jobs, backend="loky")(
        delayed(evaluate_temperature)(temp)
        for temp in tqdm(temp_grid, disable=not verbose)
    )

    admissible = [
        (temp, cov, amp)
        for temp, cov, amp in raw_results
        if amp < amplitude_threshold
    ]

    chosen_temperature = 1.
    if admissible:
        chosen_temperature = min(admissible, key=lambda r: abs(r[1] - nominal_coverage))[0]

    if verbose:
        print(f'chosen_temperature={chosen_temperature:.2f}')

    return chosen_temperature
|
|
|
|
|
|
|
|
|
|
|
|
|