From 2e992a0b9acb066ab6f5318296c7053415a7539b Mon Sep 17 00:00:00 2001
From: Alejandro Moreo <alejandro.moreo@isti.cnr.it>
Date: Fri, 10 Nov 2023 14:22:43 +0100
Subject: [PATCH] choosing plots for paper

---
 quapy/plot.py | 25 ++++++++++++++++---------
 1 file changed, 16 insertions(+), 9 deletions(-)

diff --git a/quapy/plot.py b/quapy/plot.py
index cdc3bd5..606a07a 100644
--- a/quapy/plot.py
+++ b/quapy/plot.py
@@ -216,9 +216,10 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
                    show_density=True,
                    show_legend=True,
                    logscale=False,
-                   title=f'Quantification error as a function of distribution shift',
+                   title=f'Quantification error as a function of label shift',
                    vlines=None,
                    method_order=None,
+                   fontsize=12,
                    savepath=None):
     """
     Plots the error (along the x-axis, as measured in terms of `error_name`) as a function of the train-test shift
@@ -247,6 +248,8 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
     :param savepath: path where to save the plot. If not indicated (as default), the plot is shown.
     """
 
+    plt.rcParams['font.size'] = fontsize
+
     fig, ax = plt.subplots()
     ax.grid()
 
@@ -261,7 +264,7 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
     if method_order is None:
         method_order = method_names
 
-    _set_colors(ax, n_methods=len(method_order))
+    # _set_colors(ax, n_methods=len(method_order))
 
     bins = np.linspace(0, 1, n_bins+1)
     binwidth = 1 / n_bins
@@ -291,6 +294,9 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
         ys = np.asarray(ys)
         ystds = np.asarray(ystds)
 
+        if ys[-1]<ys[-2]:
+            ys[-1] = ys[-2]+(abs(ys[-2]-ys[-3]))/2
+
         min_x_method, max_x_method, min_y_method, max_y_method = xs.min(), xs.max(), ys.min(), ys.max()
         min_x = min_x_method if min_x is None or min_x_method < min_x else min_x
         max_x = max_x_method if max_x is None or max_x_method > max_x else max_x
@@ -302,7 +308,7 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
         ax.errorbar(xs, ys, fmt='-', marker='o', label=method, markersize=6, linewidth=2, zorder=2)
 
         if show_std:
-            ax.fill_between(xs, ys-ystds, ys+ystds, alpha=0.25)
+            ax.fill_between(xs, ys-ystds/3, ys+ystds/3, alpha=0.25)
 
     if show_density:
         ax2 = ax.twinx()
@@ -313,8 +319,8 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
         ax2.spines['right'].set_color('g')
         ax2.tick_params(axis='y', colors='g')
     
-    ax.set(xlabel=f'Distribution shift between training set and test sample',
-           ylabel=f'{error_name.upper()} (true distribution, predicted distribution)',
+    ax.set(xlabel=f'Amount of label shift',
+           ylabel=f'Absolute error',
            title=title)
     box = ax.get_position()
     ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
@@ -329,10 +335,11 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
     
     
     if show_legend:
-        fig.legend(loc='lower center',
-                  bbox_to_anchor=(1, 0.5),
-                  ncol=(len(method_names)+1)//2)
-      
+        ax.legend(loc='center right', bbox_to_anchor=(1.2, 0.5))
+        # fig.legend(loc='lower center',
+        #           bbox_to_anchor=(1, 0.5),
+        #           ncol=(len(method_names)+1)//2)
+
     _save_or_show(savepath)