diff --git a/TweetSentQuant/gen_plots.py b/TweetSentQuant/gen_plots.py index 82818e8..6f19d56 100644 --- a/TweetSentQuant/gen_plots.py +++ b/TweetSentQuant/gen_plots.py @@ -12,8 +12,8 @@ from os.path import join qp.environ['SAMPLE_SIZE'] = settings.SAMPLE_SIZE plotext='png' -resultdir = './results' -plotdir = './plots' +resultdir = './results_npp' +plotdir = './plots_npp' os.makedirs(plotdir, exist_ok=True) def gather_results(methods, error_name): @@ -50,6 +50,7 @@ def plot_error_by_drift(methods, error_name, logscale=False, path=None): logscale=logscale, title=f'Quantification error as a function of distribution shift', savepath=path, + vlines=[0.02, 0.1055], method_order=method_order ) diff --git a/quapy/plot.py b/quapy/plot.py index b902257..a32ce15 100644 --- a/quapy/plot.py +++ b/quapy/plot.py @@ -176,10 +176,12 @@ def _set_colors(ax, n_methods): ax.set_prop_cycle(color=[cm(1. * i / NUM_COLORS) for i in range(NUM_COLORS)]) -def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, error_name='ae', show_std=True, +def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, error_name='ae', show_std=False, + show_density=True, logscale=False, title=f'Quantification error as a function of distribution shift', savepath=None, + vlines=None, method_order=None): fig, ax = plt.subplots() @@ -246,15 +248,20 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, e if show_std: ax.fill_between(xs, ys-ystds, ys+ystds, alpha=0.25) - ax.bar([ind * binwidth-binwidth/2 for ind in range(len(bins))], max_y*npoints/np.max(npoints), alpha=0.15, color='g', width=binwidth, label='density') + if show_density: + ax.bar([ind * binwidth-binwidth/2 for ind in range(len(bins))], + max_y*npoints/np.max(npoints), alpha=0.15, color='g', width=binwidth, label='density') ax.set(xlabel=f'Distribution shift between training set and test sample', ylabel=f'{error_name.upper()} (true distribution, predicted distribution)', title=title) box = ax.get_position() ax.set_position([box.x0, box.y0, box.width * 0.8, box.height]) - ax.axvline(0.02, 0, 1, linestyle='--', color='k') - ax.axvline(0.1055, 0, 1, linestyle='--', color='k') + if vlines: + for vline in vlines: + ax.axvline(vline, 0, 1, linestyle='--', color='k') + # ax.axvline(0.02, 0, 1, linestyle='--', color='k') + # ax.axvline(0.1055, 0, 1, linestyle='--', color='k') ax.set_xlim(0, max_x) ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))