autolike+ fixed

This commit is contained in:
Alejandro Moreo Fernandez 2024-09-25 17:23:43 +02:00
parent 531badffc8
commit 4fa4540aab
2 changed files with 14 additions and 20 deletions

View File

@ -308,41 +308,35 @@ class KDEyMLauto2(KDEyML):
prot = UPP(val, sample_size=self.reduction, repeats=repeats, random_state=self.random_state)
if self.target == 'likelihood+':
def neg_loglikelihood_band_(bandwidth):
bandwidth=bandwidth[0]
def neg_loglikelihood_bandwidth(bandwidth):
mix_densities = self.get_mixture_components(*train.Xy, train.classes_, bandwidth)
loss_accum = 0
for (sample, prevtrue) in prot():
test_densities2 = [self.pdf(kde_i, sample) for kde_i in mix_densities]
test_densities = [self.pdf(kde_i, sample) for kde_i in mix_densities]
def neg_loglikelihood_prev(prev):
test_mixture_likelihood = sum(prev_i * dens_i for prev_i, dens_i in zip(prev, test_densities2))
test_mixture_likelihood = sum(prev_i * dens_i for prev_i, dens_i in zip(prev, test_densities))
test_loglikelihood = np.log(test_mixture_likelihood + epsilon)
nll = -np.sum(test_loglikelihood)
# print(f'\t\tprev={F.strprev(prev)} got {nll=}')
return nll
init_prev = np.full(fill_value=1 / n_classes, shape=(n_classes,))
pred_prev, neglikelihood = optim_minimize(neg_loglikelihood_prev, init_prev, return_loss=True)
# print(f'\t\tprev={F.strprev(pred_prev)} (true={F.strprev(prev)}) got {neglikelihood=}')
loss_accum += neglikelihood
print(f'\t{bandwidth=:.8f} got {loss_accum=:.8f}')
return loss_accum
bounds = [tuple((0.0001, 0.2))]
init_bandwidth = 0.1
r = optimize.minimize(neg_loglikelihood_band_, x0=[init_bandwidth], method='Nelder-Mead', bounds=bounds, tol=1)
best_band = r.x[0]
best_loss_val = r.fun
r = optimize.minimize_scalar(neg_loglikelihood_bandwidth, bounds=(0.00001, 0.2))
best_band = r.x
best_loss_value = r.fun
nit = r.nit
assert r.success, 'Process did not converge!'
# assert r.success, 'Process did not converge!'
#found bandwidth=0.00994664 after nit=3 iterations loss_val=-212247.24305)
else:
best_band = None
best_loss_val = None
best_loss_value = None
init_prev = np.full(fill_value=1 / n_classes, shape=(n_classes,))
for bandwidth in np.logspace(-4, np.log10(0.2), 20):
mix_densities = self.get_mixture_components(*train.Xy, train.classes_, bandwidth)
@ -364,12 +358,12 @@ class KDEyMLauto2(KDEyML):
pred_prev, loss_val = optim_minimize(loss_fn, init_prev, return_loss=True)
loss_accum += loss_val
if best_loss_val is None or loss_accum < best_loss_val:
best_loss_val = loss_accum
if best_loss_value is None or loss_accum < best_loss_value:
best_loss_value = loss_accum
best_band = bandwidth
nit=20
print(f'found bandwidth={best_band:.8f} after {nit=} iterations loss_val={best_loss_val:.5f})')
print(f'found bandwidth={best_band:.8f} after {nit=} iterations loss_val={best_loss_value:.5f})')
self.bandwidth_ = best_band

View File

@ -39,7 +39,7 @@ METHODS = [
('KDEy-ML-scott', KDEyML(newLR(), bandwidth='scott'), wrap_hyper(logreg_grid)),
('KDEy-ML-silver', KDEyML(newLR(), bandwidth='silverman'), wrap_hyper(logreg_grid)),
('KDEy-ML-autoLike', KDEyMLauto2(newLR(), bandwidth='auto', target='likelihood'), wrap_hyper(logreg_grid)),
# ('KDEy-ML-autoLike+', KDEyMLauto2(newLR(), bandwidth='auto', target='likelihood+'), wrap_hyper(logreg_grid)), <-- no funciona
('KDEy-ML-autoLike+', KDEyMLauto2(newLR(), bandwidth='auto', target='likelihood+'), wrap_hyper(logreg_grid)),
('KDEy-ML-autoAE', KDEyMLauto2(newLR(), bandwidth='auto', target='mae'), wrap_hyper(logreg_grid)),
('KDEy-ML-autoRAE', KDEyMLauto2(newLR(), bandwidth='auto', target='mrae'), wrap_hyper(logreg_grid)),
]
@ -55,7 +55,7 @@ TRANSDUCTIVE_METHODS = [
#('TKDEy-ML', KDEyMLauto(newLR()), None),
# ('TKDEy-MLboth', KDEyMLauto(newLR(), optim='both'), None),
# ('TKDEy-MLbothfine', KDEyMLauto(newLR(), optim='both_fine'), None),
('TKDEy-ML2', KDEyMLauto(newLR(), optim='two_steps'), None),
# ('TKDEy-ML2', KDEyMLauto(newLR(), optim='two_steps'), None),
# ('TKDEy-MLike', KDEyMLauto(newLR(), optim='max_likelihood'), None),
# ('TKDEy-MLike2', KDEyMLauto(newLR(), optim='max_likelihood2'), None),
#('TKDEy-ML3', KDEyMLauto(newLR()), None),