forked from moreo/QuaPy
adding density to error-by-drift plot
This commit is contained in:
parent
bdbe933a41
commit
7fd32d5c5f
|
@ -12,8 +12,8 @@ from os.path import join
|
|||
qp.environ['SAMPLE_SIZE'] = settings.SAMPLE_SIZE
|
||||
plotext='png'
|
||||
|
||||
resultdir = './results'
|
||||
plotdir = './plots'
|
||||
resultdir = './results_npp'
|
||||
plotdir = './plots_npp'
|
||||
os.makedirs(plotdir, exist_ok=True)
|
||||
|
||||
def gather_results(methods, error_name):
|
||||
|
@ -50,6 +50,7 @@ def plot_error_by_drift(methods, error_name, logscale=False, path=None):
|
|||
logscale=logscale,
|
||||
title=f'Quantification error as a function of distribution shift',
|
||||
savepath=path,
|
||||
vlines=[0.02, 0.1055],
|
||||
method_order=method_order
|
||||
)
|
||||
|
||||
|
|
|
@ -176,10 +176,12 @@ def _set_colors(ax, n_methods):
|
|||
ax.set_prop_cycle(color=[cm(1. * i / NUM_COLORS) for i in range(NUM_COLORS)])
|
||||
|
||||
|
||||
def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, error_name='ae', show_std=True,
|
||||
def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, error_name='ae', show_std=False,
|
||||
show_density=True,
|
||||
logscale=False,
|
||||
title=f'Quantification error as a function of distribution shift',
|
||||
savepath=None,
|
||||
vlines=None,
|
||||
method_order=None):
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
|
@ -246,15 +248,20 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, e
|
|||
if show_std:
|
||||
ax.fill_between(xs, ys-ystds, ys+ystds, alpha=0.25)
|
||||
|
||||
ax.bar([ind * binwidth-binwidth/2 for ind in range(len(bins))], max_y*npoints/np.max(npoints), alpha=0.15, color='g', width=binwidth, label='density')
|
||||
if show_density:
|
||||
ax.bar([ind * binwidth-binwidth/2 for ind in range(len(bins))],
|
||||
max_y*npoints/np.max(npoints), alpha=0.15, color='g', width=binwidth, label='density')
|
||||
|
||||
ax.set(xlabel=f'Distribution shift between training set and test sample',
|
||||
ylabel=f'{error_name.upper()} (true distribution, predicted distribution)',
|
||||
title=title)
|
||||
box = ax.get_position()
|
||||
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
|
||||
ax.axvline(0.02, 0, 1, linestyle='--', color='k')
|
||||
ax.axvline(0.1055, 0, 1, linestyle='--', color='k')
|
||||
if vlines:
|
||||
for vline in vlines:
|
||||
ax.axvline(vline, 0, 1, linestyle='--', color='k')
|
||||
# ax.axvline(0.02, 0, 1, linestyle='--', color='k')
|
||||
# ax.axvline(0.1055, 0, 1, linestyle='--', color='k')
|
||||
ax.set_xlim(0, max_x)
|
||||
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
|
||||
|
||||
|
|
Loading…
Reference in New Issue