adding density to error-by-drift plot
This commit is contained in:
parent
bdbe933a41
commit
7fd32d5c5f
|
@ -12,8 +12,8 @@ from os.path import join
|
||||||
qp.environ['SAMPLE_SIZE'] = settings.SAMPLE_SIZE
|
qp.environ['SAMPLE_SIZE'] = settings.SAMPLE_SIZE
|
||||||
plotext='png'
|
plotext='png'
|
||||||
|
|
||||||
resultdir = './results'
|
resultdir = './results_npp'
|
||||||
plotdir = './plots'
|
plotdir = './plots_npp'
|
||||||
os.makedirs(plotdir, exist_ok=True)
|
os.makedirs(plotdir, exist_ok=True)
|
||||||
|
|
||||||
def gather_results(methods, error_name):
|
def gather_results(methods, error_name):
|
||||||
|
@ -50,6 +50,7 @@ def plot_error_by_drift(methods, error_name, logscale=False, path=None):
|
||||||
logscale=logscale,
|
logscale=logscale,
|
||||||
title=f'Quantification error as a function of distribution shift',
|
title=f'Quantification error as a function of distribution shift',
|
||||||
savepath=path,
|
savepath=path,
|
||||||
|
vlines=[0.02, 0.1055],
|
||||||
method_order=method_order
|
method_order=method_order
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -176,10 +176,12 @@ def _set_colors(ax, n_methods):
|
||||||
ax.set_prop_cycle(color=[cm(1. * i / NUM_COLORS) for i in range(NUM_COLORS)])
|
ax.set_prop_cycle(color=[cm(1. * i / NUM_COLORS) for i in range(NUM_COLORS)])
|
||||||
|
|
||||||
|
|
||||||
def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, error_name='ae', show_std=True,
|
def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, error_name='ae', show_std=False,
|
||||||
|
show_density=True,
|
||||||
logscale=False,
|
logscale=False,
|
||||||
title=f'Quantification error as a function of distribution shift',
|
title=f'Quantification error as a function of distribution shift',
|
||||||
savepath=None,
|
savepath=None,
|
||||||
|
vlines=None,
|
||||||
method_order=None):
|
method_order=None):
|
||||||
|
|
||||||
fig, ax = plt.subplots()
|
fig, ax = plt.subplots()
|
||||||
|
@ -246,15 +248,20 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, e
|
||||||
if show_std:
|
if show_std:
|
||||||
ax.fill_between(xs, ys-ystds, ys+ystds, alpha=0.25)
|
ax.fill_between(xs, ys-ystds, ys+ystds, alpha=0.25)
|
||||||
|
|
||||||
ax.bar([ind * binwidth-binwidth/2 for ind in range(len(bins))], max_y*npoints/np.max(npoints), alpha=0.15, color='g', width=binwidth, label='density')
|
if show_density:
|
||||||
|
ax.bar([ind * binwidth-binwidth/2 for ind in range(len(bins))],
|
||||||
|
max_y*npoints/np.max(npoints), alpha=0.15, color='g', width=binwidth, label='density')
|
||||||
|
|
||||||
ax.set(xlabel=f'Distribution shift between training set and test sample',
|
ax.set(xlabel=f'Distribution shift between training set and test sample',
|
||||||
ylabel=f'{error_name.upper()} (true distribution, predicted distribution)',
|
ylabel=f'{error_name.upper()} (true distribution, predicted distribution)',
|
||||||
title=title)
|
title=title)
|
||||||
box = ax.get_position()
|
box = ax.get_position()
|
||||||
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
|
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
|
||||||
ax.axvline(0.02, 0, 1, linestyle='--', color='k')
|
if vlines:
|
||||||
ax.axvline(0.1055, 0, 1, linestyle='--', color='k')
|
for vline in vlines:
|
||||||
|
ax.axvline(vline, 0, 1, linestyle='--', color='k')
|
||||||
|
# ax.axvline(0.02, 0, 1, linestyle='--', color='k')
|
||||||
|
# ax.axvline(0.1055, 0, 1, linestyle='--', color='k')
|
||||||
ax.set_xlim(0, max_x)
|
ax.set_xlim(0, max_x)
|
||||||
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
|
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue