From cbe3f410edceb88bcbc4d03a0d49dc6d2fe25d28 Mon Sep 17 00:00:00 2001 From: Alex Moreo Date: Fri, 20 May 2022 11:52:59 +0200 Subject: [PATCH 1/2] updating diagonal plot legend --- quapy/plot.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/quapy/plot.py b/quapy/plot.py index cdb9b1e..cafe520 100644 --- a/quapy/plot.py +++ b/quapy/plot.py @@ -7,9 +7,9 @@ from scipy.stats import ttest_ind_from_stats import quapy as qp -plt.rcParams['figure.figsize'] = [12, 8] +plt.rcParams['figure.figsize'] = [10, 6] plt.rcParams['figure.dpi'] = 200 -plt.rcParams['font.size'] = 16 +plt.rcParams['font.size'] = 18 def binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=1, title=None, show_std=True, legend=True, @@ -49,9 +49,9 @@ def binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=1, title=No table = {method_name:[true_prev, estim_prev] for method_name, true_prev, estim_prev in order} order = [(method_name, *table[method_name]) for method_name in method_order] - cm = plt.get_cmap('tab20') - NUM_COLORS = len(method_names) - ax.set_prop_cycle(color=[cm(1. * i / NUM_COLORS) for i in range(NUM_COLORS)]) + #cm = plt.get_cmap('tab20') + #NUM_COLORS = len(method_names) + #ax.set_prop_cycle(color=[cm(1. * i / NUM_COLORS) for i in range(NUM_COLORS)]) for method, true_prev, estim_prev in order: true_prev = true_prev[:,pos_class] estim_prev = estim_prev[:,pos_class] @@ -76,11 +76,11 @@ def binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=1, title=No if legend: # box = ax.get_position() # ax.set_position([box.x0, box.y0, box.width * 0.8, box.height]) - # ax.legend(loc='center left', bbox_to_anchor=(1, 0.5)) + ax.legend(loc='center left', bbox_to_anchor=(1, 0.5)) # ax.set_position([box.x0, box.y0, box.width * 0.8, box.height]) - ax.legend(loc='lower center', - bbox_to_anchor=(1, -0.5), - ncol=(len(method_names)+1)//2) + #ax.legend(loc='lower center', + # bbox_to_anchor=(1, -0.5), + # ncol=(len(method_names)+1)//2) _save_or_show(savepath) From 140ab3bfc9a7cf398e49e601f1c1a16a9fdb8e5c Mon Sep 17 00:00:00 2001 From: Alex Moreo Date: Wed, 22 Feb 2023 11:57:22 +0100 Subject: [PATCH 2/2] adding sanity check to APP, in order to prevent the user unattendedly runs into a never-endting loop of samples being generated --- quapy/protocol.py | 16 ++++++++++++++-- quapy/tests/test_protocols.py | 13 +++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/quapy/protocol.py b/quapy/protocol.py index 9361f1d..9bb716a 100644 --- a/quapy/protocol.py +++ b/quapy/protocol.py @@ -214,18 +214,30 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol): :param smooth_limits_epsilon: the quantity to add and subtract to the limits 0 and 1 :param random_state: allows replicating samples across runs (default 0, meaning that the sequence of samples will be the same every time the protocol is called) + :param sanity_check: int, raises an exception warning the user that the number of examples to be generated exceed + this number; set to None for skipping this check :param return_type: set to "sample_prev" (default) to get the pairs of (sample, prevalence) at each iteration, or to "labelled_collection" to get instead instances of LabelledCollection """ - def __init__(self, data:LabelledCollection, sample_size=None, n_prevalences=21, repeats=10, - smooth_limits_epsilon=0, random_state=0, return_type='sample_prev'): + def __init__(self, data: LabelledCollection, sample_size=None, n_prevalences=21, repeats=10, + smooth_limits_epsilon=0, random_state=0, sanity_check=10000, return_type='sample_prev'): super(APP, self).__init__(random_state) self.data = data self.sample_size = qp._get_sample_size(sample_size) self.n_prevalences = n_prevalences self.repeats = repeats self.smooth_limits_epsilon = smooth_limits_epsilon + if not ((isinstance(sanity_check, int) and sanity_check>0) or sanity_check is None): + raise ValueError('param "sanity_check" must either be None or a positive integer') + if isinstance(sanity_check, int): + n = F.num_prevalence_combinations(n_prevpoints=n_prevalences, n_classes=data.n_classes, n_repeats=repeats) + if n > sanity_check: + raise RuntimeError( + f"Abort: the number of samples that will be generated by {self.__class__.__name__} ({n}) " + f"exceeds the maximum number of allowed samples ({sanity_check = }). Set 'sanity_check' to " + f"None for bypassing this check, or to a higher number.") + self.collator = OnLabelledCollectionProtocol.get_collator(return_type) def prevalence_grid(self): diff --git a/quapy/tests/test_protocols.py b/quapy/tests/test_protocols.py index 6c76d4b..87bd358 100644 --- a/quapy/tests/test_protocols.py +++ b/quapy/tests/test_protocols.py @@ -1,5 +1,7 @@ import unittest import numpy as np + +import quapy.functional from quapy.data import LabelledCollection from quapy.protocol import APP, NPP, UPP, DomainMixer, AbstractStochasticSeededProtocol @@ -19,6 +21,17 @@ def samples_to_str(protocol): class TestProtocols(unittest.TestCase): + def test_app_sanity_check(self): + data = mock_labelled_collection() + n_prevpoints = 101 + repeats = 10 + with self.assertRaises(RuntimeError): + p = APP(data, sample_size=5, n_prevalences=n_prevpoints, repeats=repeats, random_state=42) + n_combinations = \ + quapy.functional.num_prevalence_combinations(n_prevpoints, n_classes=data.n_classes, n_repeats=repeats) + p = APP(data, sample_size=5, n_prevalences=n_prevpoints, random_state=42, sanity_check=n_combinations) + p = APP(data, sample_size=5, n_prevalences=n_prevpoints, random_state=42, sanity_check=None) + def test_app_replicate(self): data = mock_labelled_collection() p = APP(data, sample_size=5, n_prevalences=11, random_state=42)