117 lines
4.1 KiB
Python
117 lines
4.1 KiB
Python
from build.lib.quapy.data import LabelledCollection
|
|
from quapy.method.confidence import WithConfidenceABC
|
|
from quapy.protocol import AbstractProtocol
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
import quapy as qp
|
|
from joblib import Parallel, delayed
|
|
import copy
|
|
|
|
|
|
def _format_calibration_stats(coverage, amplitude, winkler=None):
    """
    Formats the statistics gathered for one temperature as a single human-readable line.

    :param coverage: fraction (in [0, 1]) of validation samples whose true prevalence fell inside the
        confidence region; printed as a percentage
    :param amplitude: mean value returned by the regions' `montecarlo_proportion`; printed as a percentage
    :param winkler: mean Winkler score, or None when it was not computed (it is then omitted)
    :return: a string such as `coverage=95.00% amplitude=1.2345% winkler=0.1234`
    """
    text = f'coverage={coverage*100:.2f}% amplitude={amplitude*100:.4f}%'
    if winkler is not None:
        # the Winkler score is reported as-is (not scaled to a percentage, no '%' suffix)
        text += f' winkler={winkler:.4f}'
    return text


def temp_calibration(method:WithConfidenceABC,
                     train:LabelledCollection,
                     val_prot:AbstractProtocol,
                     temp_grid=(.5, 1., 1.5, 2., 5., 10., 100.),
                     nominal_coverage=0.95,
                     amplitude_threshold=1.,
                     criterion='winkler',
                     n_jobs=1,
                     verbose=True):
    """
    Chooses the temperature of a confidence-region quantifier by grid search over a validation protocol.

    `method` is fitted (in place) once on `train`. Then, for every temperature in `temp_grid` (one joblib job per
    temperature), a deep copy of the fitted method with its `temperature` attribute set is applied to every
    sample generated by `val_prot`. The following statistics are recorded:

    - the empirical coverage (fraction of samples whose true prevalence lies inside the confidence region);
    - the mean amplitude of the regions (as returned by `montecarlo_proportion`);
    - if `criterion='winkler'`, the mean Winkler score.

    Temperatures whose mean amplitude is not strictly below `amplitude_threshold` are discarded. Among the
    remaining ones, the chosen temperature is:

    - `criterion='winkler'`: the one with the lowest mean Winkler score;
    - `criterion='auto'`: the one whose coverage is closest to `nominal_coverage`.

    If no temperature passes the amplitude filter, 1.0 is returned.

    :param method: the quantifier to calibrate; it is fitted on `train`, but its own `temperature` is left
        untouched (only deep copies are modified)
    :param train: training collection; its `Xy` is unpacked into `method.fit`, and its `n_classes` is used
        when `amplitude_threshold='auto'`
    :param val_prot: validation protocol; calling it must yield `(sample, true_prevalence)` pairs, and
        `val_prot.total()` is used to size the progress bars
    :param temp_grid: candidate temperatures (any iterable of numbers; it is sorted before use)
    :param nominal_coverage: target coverage, only used when `criterion='auto'`
    :param amplitude_threshold: a number <= 1, or 'auto' to use `0.1 / log(n_classes + 1)`; temperatures whose
        mean amplitude is not below this value are discarded (a warning is printed if it exceeds 0.1)
    :param criterion: 'winkler' or 'auto' (see above)
    :param n_jobs: number of parallel jobs passed to `joblib.Parallel`
    :param verbose: if True, shows progress bars and prints the chosen configuration
    :return: the chosen temperature
    :raises AssertionError: if `amplitude_threshold` or `criterion` take invalid values
    :raises ValueError: if the validation protocol generates no samples
    """
    # bool is a subclass of int, but True/False are not meaningful thresholds
    is_number = isinstance(amplitude_threshold, (int, float)) and not isinstance(amplitude_threshold, bool)
    assert amplitude_threshold == 'auto' or (is_number and amplitude_threshold <= 1.), \
        f'wrong value for {amplitude_threshold=}, it must either be "auto" or a float <= 1.0.'
    assert criterion in {'auto', 'winkler'}, f'unknown {criterion=}; valid ones are auto or winkler'

    if amplitude_threshold == 'auto':
        # the threshold shrinks as the number of classes grows
        amplitude_threshold = .1 / np.log(train.n_classes + 1)
    amplitude_threshold = float(amplitude_threshold)

    if amplitude_threshold > 0.1:
        print(f'warning: the {amplitude_threshold=} is too large; this may lead to uninformative regions')

    def _evaluate_temperature_job(job_id, temp):
        # Evaluates one temperature on the whole validation protocol.
        # Returns (temp, coverage, mean_amplitude, mean_winkler_or_None).

        # each job works on its own copy, so the caller's method (and other jobs) are not affected
        local_method = copy.deepcopy(method)
        local_method.temperature = temp

        n_samples = 0
        n_covered = 0
        amplitudes = []
        winklers = []

        pbar = tqdm(val_prot(), position=job_id, total=val_prot.total(), disable=not verbose)
        for sample, prev in pbar:
            _, conf_region = local_method.predict_conf(sample)

            n_samples += 1
            if prev in conf_region:
                n_covered += 1

            amplitudes.append(conf_region.montecarlo_proportion(n_trials=50_000))
            if criterion == 'winkler':
                # NOTE(review): alpha is hard-coded to 0.005 and does not follow nominal_coverage
                # (1 - 0.95 = 0.05); confirm this is intended
                winklers.append(conf_region.mean_winkler_score(true_prev=prev, alpha=0.005))

            # building the description costs an np.mean per sample; skip it when the bar is disabled
            if verbose:
                pbar.set_description(
                    f'job={job_id} T={temp}: '
                    + _format_calibration_stats(
                        n_covered / n_samples,
                        np.mean(amplitudes),
                        np.mean(winklers) if winklers else None,
                    )
                )

        if n_samples == 0:
            raise ValueError('the validation protocol generated no samples')

        mean_coverage = n_covered / n_samples
        mean_amplitude = np.mean(amplitudes)
        winkler_mean = np.mean(winklers) if criterion == 'winkler' else None
        return temp, mean_coverage, mean_amplitude, winkler_mean

    temp_grid = sorted(temp_grid)
    method.fit(*train.Xy)

    raw_results = Parallel(n_jobs=n_jobs, backend="loky")(
        delayed(_evaluate_temperature_job)(job_id, temp)
        for job_id, temp in tqdm(enumerate(temp_grid), total=len(temp_grid), disable=not verbose)
    )

    # discard temperatures whose regions are, on average, too large to be informative
    candidates = [
        (temp, cov, amp, wink)
        for temp, cov, amp, wink in raw_results
        if amp < amplitude_threshold
    ]

    if not candidates:
        # fall back to the neutral temperature; there are no statistics to report in this case
        if verbose:
            print(f'\nno temperature achieved a mean amplitude below {amplitude_threshold=}; '
                  f'falling back to temperature=1.0')
        return 1.

    if criterion == 'winkler':
        # lowest mean Winkler score among the admissible temperatures
        selection_key = lambda res: res[3]
    else:
        # coverage closest to the nominal one, among the admissible temperatures
        selection_key = lambda res: abs(res[1] - nominal_coverage)
    chosen_temperature, ccov, camp, cwink = min(candidates, key=selection_key)

    if verbose:
        print(f'\nChosen_temperature={chosen_temperature:.2f} got '
              + _format_calibration_stats(ccov, camp, cwink))

    return chosen_temperature
|
|
|
|
|
|
|
|
|
|
|
|
|