1
0
Fork 0
QuaPy/distribution_matching/tmp/approximating_divergence_mo...

83 lines
2.2 KiB
Python

import numpy as np
from scipy.stats import norm, uniform, expon
from scipy.special import rel_entr
def kld_approx(p, q, resolution=100, low=-5, high=5):
    """Approximate KL(p || q) with the midpoint rule on a fixed interval.

    The integrand ``p(x) * log(p(x) / q(x))`` is evaluated at the midpoints of
    the ``resolution - 1`` equal-width cells spanning ``[low, high]``. Each
    value is weighted by its cell width and the results are summed. Any mass
    of ``p`` outside the interval is ignored, so the interval must cover the
    bulk of ``p``.

    Args:
        p, q: objects with a vectorized ``pdf`` method (e.g. frozen
            ``scipy.stats`` distributions).
        resolution: number of grid points; the number of cells is one less.
        low, high: integration bounds (default ``[-5, 5]``).

    Returns:
        The approximate divergence. It is ``inf`` if ``q`` has zero density at
        a midpoint where ``p`` is positive, and 0 if there are no cells.
    """
    edges = np.linspace(low, high, resolution)
    widths = np.diff(edges)
    midpoints = (edges[:-1] + edges[1:]) / 2
    # rel_entr(a, b) = a * log(a / b) with the convention 0 * log(0 / b) = 0,
    # so midpoints outside the support of p contribute 0 instead of nan.
    terms = rel_entr(p.pdf(midpoints), q.pdf(midpoints))
    return np.sum(widths * terms)
def integrate(f, resolution=10000, low=-10, high=10):
    """Approximate the integral of ``f`` over ``[low, high]`` by the midpoint rule.

    The interval is split into ``resolution - 1`` equal cells. ``f`` is called
    once per cell with the scalar midpoint, so it does not need to accept
    arrays.

    Args:
        f: callable mapping a float to a number.
        resolution: number of grid points; the number of cells is one less.
        low, high: integration bounds (default ``[-10, 10]``).

    Returns:
        The sum of ``cell_width * f(midpoint)`` over all cells (0 if there are
        no cells).
    """
    edges = np.linspace(low, high, resolution)
    return sum(
        (right - left) * f((left + right) / 2)
        for left, right in zip(edges[:-1], edges[1:])
    )
def kl_analytic(m1, s1, m2, s2):
    """Return the closed-form KL(N(m1, s1**2) || N(m2, s2**2)).

    ``s1`` and ``s2`` are standard deviations (the formula squares them to
    obtain variances).
    """
    log_scale_ratio = np.log(s2 / s1)
    mean_gap_sq = (m1 - m2) ** 2
    quadratic = (s1 * s1 + mean_gap_sq) / (2 * s2 * s2)
    return log_scale_ratio + quadratic - 0.5
def montecarlo(p, q, trials=1000000, random_state=None):
    """Monte Carlo estimate of KL(p || q) = E_{x~p}[log p(x) - log q(x)].

    Draws ``trials`` samples from ``p`` and averages the log-density ratio.
    The log-densities are subtracted directly (``logpdf``) instead of
    computing ``log(pdf / pdf)``. The ratio form yields ``inf``/``nan`` as
    soon as either density underflows to 0, e.g. when ``q`` is much narrower
    than ``p``.

    Args:
        p, q: frozen ``scipy.stats`` distributions; ``p`` needs ``rvs`` and
            both need ``logpdf``.
        trials: number of samples drawn from ``p``.
        random_state: seed or Generator forwarded to ``p.rvs``. ``None`` keeps
            the distribution's default (global) random state.

    Returns:
        The sample mean of ``log p(x) - log q(x)``.
    """
    xs = p.rvs(trials, random_state=random_state)
    return np.mean(p.logpdf(xs) - q.logpdf(xs))
def montecarlo2(p, q, trials=100000, proposal=None, random_state=None):
    """Importance-sampling estimate of KL(p || q).

    Draws x ~ r from a proposal distribution r and averages
    ``p(x) * log(p(x) / q(x)) / r(x)``. Under r, this has expectation
    KL(p || q) provided r > 0 wherever p > 0. Mass of p outside the support
    of r is missed.

    Args:
        p, q: objects with a vectorized ``pdf`` method (e.g. frozen
            ``scipy.stats`` distributions).
        trials: number of samples drawn from the proposal.
        proposal: frozen distribution with ``rvs`` and ``pdf`` used as r.
            ``None`` uses ``uniform(-10, 20)``, i.e. uniform on [-10, 10].
            Other choices, e.g. ``norm(-3, scale=5)``, can be passed in.
        random_state: seed or Generator forwarded to ``proposal.rvs``.

    Returns:
        The estimate. It is ``inf`` if q has zero density at a sample where
        p is positive.
    """
    r = uniform(-10, 20) if proposal is None else proposal
    xs = r.rvs(trials, random_state=random_state)
    rs = r.pdf(xs)  # > 0 at every sample, since the samples come from r
    # rel_entr treats p(x) == 0 exactly (contributes 0). This replaces the
    # former +1e-7 added to both densities, which perturbed every term.
    return np.mean(rel_entr(p.pdf(xs), q.pdf(xs)) / rs)
def wrong(p, q, trials=100000):
    """Deliberately incorrect KL(p || q) estimator, kept as a baseline.

    Draws x uniformly on [-10, 10] and averages ``p(x) * log(p(x) / q(x))``
    without dividing by the sampling density 1/20. The mean therefore
    approximates ``(1/20) * integral_{-10}^{10} p log(p/q) dx``, about
    KL / 20, rather than the divergence itself. Compare ``montecarlo2``,
    which applies the weight. Points where both densities are 0 would
    produce nan.

    Uses the global NumPy random state.
    """
    grid = np.random.uniform(-10, 10, trials)
    dens_p, dens_q = p.pdf(grid), q.pdf(grid)
    return (dens_p * np.log(dens_p / dens_q)).mean()
if __name__ == "__main__":
    # Compare each KL(p || q) estimator against the closed form for two
    # unit-variance Gaussians one unit apart (kl_analytic gives 0.5).
    # Guarded so that importing this module does not run the costly sampling.
    p = norm(loc=0, scale=1)
    q = norm(loc=1, scale=1)

    # Estimators run in the original order, so draws from the global random
    # state are consumed in the same sequence.
    integral_approx = kld_approx(p, q)
    analytic_solution = kl_analytic(p.kwds['loc'], p.kwds['scale'], q.kwds['loc'], q.kwds['scale'])
    montecarlo_approx = montecarlo(p, q)
    montecarlo_approx2 = montecarlo2(p, q)
    wrong_approx = wrong(p, q)

    print(f'{analytic_solution=:.10f}')
    results = (
        ('integral_approx', 'integral', integral_approx),
        ('montecarlo_approx', 'montecarlo', montecarlo_approx),
        ('montecarlo_approx2', 'montecarlo2', montecarlo_approx2),
        ('wrong_approx', 'wrong', wrong_approx),
    )
    for name, label, value in results:
        print()
        print(f'{name}={value:.10f}')
        print(f'{label} error = {np.abs(analytic_solution - value):.10f}')