diff --git a/quapy/plot.py b/quapy/plot.py index 7b105e2..5e51640 100644 --- a/quapy/plot.py +++ b/quapy/plot.py @@ -12,6 +12,11 @@ plt.rcParams['figure.dpi'] = 200 plt.rcParams['font.size'] = 16 +def _set_colors(ax, n_methods): + NUM_COLORS = n_methods + cm = plt.get_cmap('gist_rainbow') + ax.set_color_cycle([cm(1. * i / NUM_COLORS) for i in range(NUM_COLORS)]) + def binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=1, title=None, show_std=True, legend=True, train_prev=None, savepath=None): fig, ax = plt.subplots() @@ -20,6 +25,7 @@ def binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=1, title=No ax.plot([0, 1], [0, 1], '--k', label='ideal', zorder=1) method_names, true_prevs, estim_prevs = _merge(method_names, true_prevs, estim_prevs) + _set_colors(ax, n_methods=len(method_names)) for method, true_prev, estim_prev in zip(method_names, true_prevs, estim_prevs): true_prev = true_prev[:,pos_class] @@ -78,6 +84,7 @@ def binary_bias_bins(method_names, true_prevs, estim_prevs, pos_class=1, title=N ax.grid() method_names, true_prevs, estim_prevs = _merge(method_names, true_prevs, estim_prevs) + _set_colors(ax, n_methods=len(method_names)) bins = np.linspace(0, 1, nbins+1) binwidth = 1/nbins @@ -185,6 +192,8 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, e if method not in method_order: method_order.append(method) + _set_colors(ax, n_methods=len(method_order)) + bins = np.linspace(0, 1, n_bins+1) inds_histogram_global = np.zeros(n_bins, dtype=np.float) # we use this to keep track of how many datapoits contribute to each bin binwidth = 1 / n_bins