import copy

import numpy as np
from joblib import Parallel, delayed
from tqdm import tqdm

import quapy as qp
from quapy.data import LabelledCollection
from quapy.method.confidence import WithConfidenceABC
from quapy.protocol import UPP


def temp_calibration(method: WithConfidenceABC,
                     train: LabelledCollection,
                     val: LabelledCollection,
                     temp_grid=(.5, 1., 1.5, 2., 5., 10., 100.),
                     num_samples=100,
                     nominal_coverage=0.95,
                     amplitude_threshold='auto',
                     random_state=0,
                     n_jobs=1,
                     verbose=True,
                     montecarlo_trials=50_000):
    """Select the temperature whose confidence regions best match a nominal coverage.

    The method is fitted (in place) on ``train``. Then, for every temperature in
    ``temp_grid``, a deep copy of the fitted method gets its ``temperature``
    attribute set, and it is evaluated on ``num_samples`` samples drawn from
    ``val`` with the UPP protocol. For each temperature two statistics are
    computed:

    * coverage: the fraction of samples whose true prevalence lies inside the
      confidence region returned by ``predict_conf``;
    * amplitude: the mean of ``conf_region.montecarlo_proportion(...)`` over the
      samples (presumably the fraction of the simplex covered by the region).

    Only temperatures whose mean amplitude is strictly below
    ``amplitude_threshold`` are kept. Among those, the one with coverage closest
    to ``nominal_coverage`` is returned; ties go to the smallest temperature.

    :param method: confidence-aware quantifier; it is fitted in place, but its
        own ``temperature`` is never modified (only deep copies are).
    :param train: collection used to fit ``method``.
    :param val: collection from which the calibration samples are drawn.
    :param temp_grid: candidate temperatures (any iterable of numbers).
    :param num_samples: number of UPP samples evaluated per temperature.
    :param nominal_coverage: target coverage level.
    :param amplitude_threshold: ``'auto'`` or a float < 1.0. ``'auto'`` uses
        ``0.1 / log(n_classes + 1)``.
    :param random_state: seed for the UPP protocol. The same seed is used for
        every temperature, so all candidates are scored on the same samples.
    :param n_jobs: number of joblib workers (one temperature per task).
    :param verbose: whether to show progress and print the chosen temperature.
    :param montecarlo_trials: ``n_trials`` passed to
        ``conf_region.montecarlo_proportion``.
    :return: the chosen temperature, or ``1.0`` if no temperature in the grid
        satisfies the amplitude threshold.
    """
    assert amplitude_threshold == 'auto' or (
        isinstance(amplitude_threshold, float) and amplitude_threshold < 1.
    ), f'wrong value for {amplitude_threshold=}, it must either be "auto" or a float < 1.0.'

    if amplitude_threshold == 'auto':
        n_classes = train.n_classes
        # The automatic threshold shrinks as the number of classes grows.
        amplitude_threshold = .1 / np.log(n_classes + 1)

    if isinstance(amplitude_threshold, float) and amplitude_threshold > 0.1:
        print(f'warning: the {amplitude_threshold=} is too large; this may lead to uninformative regions')

    method.fit(*train.Xy)

    label_shift_prot = UPP(val, repeats=num_samples, random_state=random_state)

    def evaluate_temperature(temp):
        """Return ``(temp, mean_coverage, mean_amplitude)`` for one temperature."""
        # Work on a copy so parallel workers never share or mutate `method`.
        local_method = copy.deepcopy(method)
        local_method.temperature = temp
        coverage = 0
        amplitudes = []
        for sample, prev in label_shift_prot():
            _, conf_region = local_method.predict_conf(sample)
            if prev in conf_region:
                coverage += 1
            amplitudes.append(conf_region.montecarlo_proportion(n_trials=montecarlo_trials))
        mean_coverage = coverage / label_shift_prot.total()
        mean_amplitude = np.mean(amplitudes)
        return temp, mean_coverage, mean_amplitude

    temp_grid = sorted(temp_grid)
    # NOTE: tqdm wraps the task generator, so with n_jobs > 1 the bar tracks
    # task dispatch rather than task completion.
    raw_results = Parallel(n_jobs=n_jobs, backend="loky")(
        delayed(evaluate_temperature)(temp)
        for temp in tqdm(temp_grid, disable=not verbose)
    )

    results = [
        (temp, cov, amp)
        for temp, cov, amp in raw_results
        if amp < amplitude_threshold
    ]

    chosen_temperature = 1.
    if len(results) > 0:
        chosen_temperature = min(results, key=lambda x: abs(x[1] - nominal_coverage))[0]

    if verbose:
        print(f'chosen_temperature={chosen_temperature:.2f}')

    return chosen_temperature