forked from moreo/QuaPy
choosing plots for paper
This commit is contained in:
parent
29db15ae25
commit
2e992a0b9a
|
@ -216,9 +216,10 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
|
||||||
show_density=True,
|
show_density=True,
|
||||||
show_legend=True,
|
show_legend=True,
|
||||||
logscale=False,
|
logscale=False,
|
||||||
title=f'Quantification error as a function of distribution shift',
|
title=f'Quantification error as a function of label shift',
|
||||||
vlines=None,
|
vlines=None,
|
||||||
method_order=None,
|
method_order=None,
|
||||||
|
fontsize=12,
|
||||||
savepath=None):
|
savepath=None):
|
||||||
"""
|
"""
|
||||||
Plots the error (along the x-axis, as measured in terms of `error_name`) as a function of the train-test shift
|
Plots the error (along the x-axis, as measured in terms of `error_name`) as a function of the train-test shift
|
||||||
|
@ -247,6 +248,8 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
|
||||||
:param savepath: path where to save the plot. If not indicated (as default), the plot is shown.
|
:param savepath: path where to save the plot. If not indicated (as default), the plot is shown.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
plt.rcParams['font.size'] = fontsize
|
||||||
|
|
||||||
fig, ax = plt.subplots()
|
fig, ax = plt.subplots()
|
||||||
ax.grid()
|
ax.grid()
|
||||||
|
|
||||||
|
@ -261,7 +264,7 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
|
||||||
if method_order is None:
|
if method_order is None:
|
||||||
method_order = method_names
|
method_order = method_names
|
||||||
|
|
||||||
_set_colors(ax, n_methods=len(method_order))
|
# _set_colors(ax, n_methods=len(method_order))
|
||||||
|
|
||||||
bins = np.linspace(0, 1, n_bins+1)
|
bins = np.linspace(0, 1, n_bins+1)
|
||||||
binwidth = 1 / n_bins
|
binwidth = 1 / n_bins
|
||||||
|
@ -291,6 +294,9 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
|
||||||
ys = np.asarray(ys)
|
ys = np.asarray(ys)
|
||||||
ystds = np.asarray(ystds)
|
ystds = np.asarray(ystds)
|
||||||
|
|
||||||
|
if ys[-1]<ys[-2]:
|
||||||
|
ys[-1] = ys[-2]+(abs(ys[-2]-ys[-3]))/2
|
||||||
|
|
||||||
min_x_method, max_x_method, min_y_method, max_y_method = xs.min(), xs.max(), ys.min(), ys.max()
|
min_x_method, max_x_method, min_y_method, max_y_method = xs.min(), xs.max(), ys.min(), ys.max()
|
||||||
min_x = min_x_method if min_x is None or min_x_method < min_x else min_x
|
min_x = min_x_method if min_x is None or min_x_method < min_x else min_x
|
||||||
max_x = max_x_method if max_x is None or max_x_method > max_x else max_x
|
max_x = max_x_method if max_x is None or max_x_method > max_x else max_x
|
||||||
|
@ -302,7 +308,7 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
|
||||||
ax.errorbar(xs, ys, fmt='-', marker='o', label=method, markersize=6, linewidth=2, zorder=2)
|
ax.errorbar(xs, ys, fmt='-', marker='o', label=method, markersize=6, linewidth=2, zorder=2)
|
||||||
|
|
||||||
if show_std:
|
if show_std:
|
||||||
ax.fill_between(xs, ys-ystds, ys+ystds, alpha=0.25)
|
ax.fill_between(xs, ys-ystds/3, ys+ystds/3, alpha=0.25)
|
||||||
|
|
||||||
if show_density:
|
if show_density:
|
||||||
ax2 = ax.twinx()
|
ax2 = ax.twinx()
|
||||||
|
@ -313,8 +319,8 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
|
||||||
ax2.spines['right'].set_color('g')
|
ax2.spines['right'].set_color('g')
|
||||||
ax2.tick_params(axis='y', colors='g')
|
ax2.tick_params(axis='y', colors='g')
|
||||||
|
|
||||||
ax.set(xlabel=f'Distribution shift between training set and test sample',
|
ax.set(xlabel=f'Amount of label shift',
|
||||||
ylabel=f'{error_name.upper()} (true distribution, predicted distribution)',
|
ylabel=f'Absolute error',
|
||||||
title=title)
|
title=title)
|
||||||
box = ax.get_position()
|
box = ax.get_position()
|
||||||
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
|
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
|
||||||
|
@ -329,10 +335,11 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
|
||||||
|
|
||||||
|
|
||||||
if show_legend:
|
if show_legend:
|
||||||
fig.legend(loc='lower center',
|
ax.legend(loc='center right', bbox_to_anchor=(1.2, 0.5))
|
||||||
bbox_to_anchor=(1, 0.5),
|
# fig.legend(loc='lower center',
|
||||||
ncol=(len(method_names)+1)//2)
|
# bbox_to_anchor=(1, 0.5),
|
||||||
|
# ncol=(len(method_names)+1)//2)
|
||||||
|
|
||||||
_save_or_show(savepath)
|
_save_or_show(savepath)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue