From 238a30520c030a155f7910400aa73a11439ef763 Mon Sep 17 00:00:00 2001
From: Alex Moreo <alejandro.moreo@isti.cnr.it>
Date: Mon, 8 Nov 2021 18:01:49 +0100
Subject: [PATCH] adapting everything to the new format

---
 LeQua2022/baselines_T1Amodsel.py |  9 ++--
 LeQua2022/constants.py           |  5 +++
 LeQua2022/data.py                | 70 ++++++++++++++++----------------
 LeQua2022/evaluation.py          |  6 ++-
 LeQua2022/predict.py             |  6 +--
 quapy/evaluation.py              |  2 +-
 quapy/model_selection.py         | 55 ++++++++++++-------------
 7 files changed, 81 insertions(+), 72 deletions(-)

diff --git a/LeQua2022/baselines_T1Amodsel.py b/LeQua2022/baselines_T1Amodsel.py
index f291714..01b02a4 100644
--- a/LeQua2022/baselines_T1Amodsel.py
+++ b/LeQua2022/baselines_T1Amodsel.py
@@ -40,14 +40,17 @@ print(f'training matrix shape: {train.instances.shape}')
 
 true_prevalence = ResultSubmission.load(T1A_devprevalence_path)
 
-param_grid = {'C': np.logspace(-3,3,7), 'class_weight': ['balanced', None]}
+param_grid = {
+    'C': np.logspace(-3,3,7),
+    'class_weight': ['balanced', None]
+}
 
 
 def gen_samples():
     return gen_load_samples_T1(T1A_devvectors_path, nF, ground_truth_path=T1A_devprevalence_path, return_id=False)
 
 
-for quantifier in [CC, ACC, PCC, PACC, EMQ, HDy]:
+for quantifier in [CC]: #, ACC, PCC, PACC, EMQ, HDy]:
     #classifier = CalibratedClassifierCV(LogisticRegression(), n_jobs=-1)
     classifier = LogisticRegression()
     model = quantifier(classifier)
@@ -66,7 +69,7 @@ for quantifier in [CC, ACC, PCC, PACC, EMQ, HDy]:
     print(f'{quantifier_name} mae={model.best_score_:.3f} (params: {model.best_params_})')
 
     pickle.dump(model.best_model(),
-                open(os.path.join(models_path, quantifier_name+'.modsel.pkl'), 'wb'),
+                open(os.path.join(models_path, quantifier_name+'.pkl'), 'wb'),
                 protocol=pickle.HIGHEST_PROTOCOL)
 
 
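# --- illustration (not part of the patch) -----------------------------------
# A minimal sketch of how the loop above can wire gen_samples() into quapy's
# GridSearchQ using the protocol='gen' option documented further down in this
# patch. The concrete keyword values (eval_budget, error, refit) are
# illustrative assumptions, not taken from the original script.
from quapy.model_selection import GridSearchQ

qsearch = GridSearchQ(
    model=quantifier(LogisticRegression()),   # e.g. CC(LogisticRegression())
    param_grid=param_grid,
    sample_size=None,        # ignored when protocol='gen'
    protocol='gen',          # evaluate on samples produced by a generator
    eval_budget=1000,        # e.g. at most one evaluation per dev sample
    error='mae',
    refit=False              # refit is ignored for protocol='gen'
)
qsearch.fit(train, val_split=gen_samples)     # gen_samples is the callable defined above
print(qsearch.best_params_, qsearch.best_score_)
# -----------------------------------------------------------------------------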
diff --git a/LeQua2022/constants.py b/LeQua2022/constants.py
index dee7f8c..11a78ce 100644
--- a/LeQua2022/constants.py
+++ b/LeQua2022/constants.py
@@ -1,7 +1,12 @@
 DEV_SAMPLES = 1000
 TEST_SAMPLES = 5000
 
+TXA_SAMPLE_SIZE = 250
+TXB_SAMPLE_SIZE = 250
+
 T1A_SAMPLE_SIZE = 250
 T1B_SAMPLE_SIZE = 1000
+T2A_SAMPLE_SIZE = 250
+T2B_SAMPLE_SIZE = 1000
 
 ERROR_TOL = 1E-3
diff --git a/LeQua2022/data.py b/LeQua2022/data.py
index 596a113..bcea49f 100644
--- a/LeQua2022/data.py
+++ b/LeQua2022/data.py
@@ -80,21 +80,30 @@ def gen_load_samples_T2B(path_dir:str, ground_truth_path:str = None):
 
 class ResultSubmission:
 
-    def __init__(self, categories: List[str]):
-        if not isinstance(categories, list) or len(categories) < 2:
-            raise TypeError('wrong format for categories; a list with at least two category names (str) was expected')
-        self.categories = categories
-        self.df = pd.DataFrame(columns=list(categories))
-        self.df.index.rename('id', inplace=True)
+    def __init__(self):
+        self.df = None
+
+    def __init_df(self, categories:int):
+        if not isinstance(categories, int) or categories < 2:
+            raise TypeError('wrong format for categories: an int (>=2) was expected')
+        df = pd.DataFrame(columns=list(range(categories)))
+        df.index.set_names('id', inplace=True)
+        self.df = df
+
+    @property
+    def n_categories(self):
+        return len(self.df.columns.values)
 
     def add(self, sample_id:int, prevalence_values:np.ndarray):
         if not isinstance(sample_id, int):
             raise TypeError(f'error: expected int for sample_id, found {type(sample_id)}')
         if not isinstance(prevalence_values, np.ndarray):
             raise TypeError(f'error: expected np.ndarray for prevalence_values, found {type(prevalence_values)}')
+        if self.df is None:
+            self.__init_df(categories=len(prevalence_values))
         if sample_id in self.df.index.values:
             raise ValueError(f'error: prevalence values for "{sample_id}" already added')
-        if prevalence_values.ndim!=1 and prevalence_values.size != len(self.categories):
+        if prevalence_values.ndim != 1 or prevalence_values.size != self.n_categories:
             raise ValueError(f'error: wrong shape found for prevalence vector {prevalence_values}')
         if (prevalence_values<0).any() or (prevalence_values>1).any():
             raise ValueError(f'error: prevalence values out of range [0,1] for "{sample_id}"')
@@ -102,9 +111,7 @@ class ResultSubmission:
             raise ValueError(f'error: prevalence values do not sum up to one for "{sample_id}"'
                              f'(error tolerance {constants.ERROR_TOL})')
 
-        # new_entry = dict([('id', sample_id)] + [(col_i, prev_i) for col_i, prev_i in enumerate(prevalence_values)])
-        new_entry = pd.DataFrame(prevalence_values.reshape(1,2), index=[sample_id], columns=self.df.columns)
-        self.df = self.df.append(new_entry, ignore_index=False)
+        self.df.loc[sample_id] = prevalence_values
 
     def __len__(self):
         return len(self.df)
@@ -112,7 +119,7 @@ class ResultSubmission:
     @classmethod
     def load(cls, path: str) -> 'ResultSubmission':
         df = ResultSubmission.check_file_format(path)
-        r = ResultSubmission(categories=df.columns.values.tolist())
+        r = ResultSubmission()
         r.df = df
         return r
 
@@ -120,16 +127,15 @@ class ResultSubmission:
         ResultSubmission.check_dataframe_format(self.df)
         self.df.to_csv(path)
 
-    def prevalence(self, sample_name:str):
-        sel = self.df.loc[self.df['filename'] == sample_name]
+    def prevalence(self, sample_id:int):
+        sel = self.df.loc[sample_id]
         if sel.empty:
             return None
         else:
-            return sel.loc[:,self.df.columns[1]:].values.flatten()
+            return sel.values.flatten()
 
     def iterrows(self):
         for index, row in self.df.iterrows():
-            # filename = row.filename
             prevalence = row.values.flatten()
             yield index, prevalence
 
@@ -146,10 +152,11 @@ class ResultSubmission:
 
         if df.index.name != 'id' or len(df.columns) < 2:
             raise ValueError(f'wrong header{hint_path}, '
-                             f'the format of the header should be "id,<cat_1>,...,<cat_n>"')
+                             f'the format of the header should be "id,0,...,n-1", '
+                             f'where n is the number of categories')
         if [int(ci) for ci in df.columns.values] != list(range(len(df.columns))):
-            raise ValueError(f'wrong header{hint_path}, category ids should be 0,1,2,...,n')
-
+            raise ValueError(f'wrong header{hint_path}, category ids should be 0,1,2,...,n-1, '
+                             f'where n is the number of categories')
         if df.empty:
             raise ValueError(f'error{hint_path}: results file is empty')
         elif len(df) != constants.DEV_SAMPLES and len(df) != constants.TEST_SAMPLES:
@@ -167,9 +174,9 @@ class ResultSubmission:
             if unexpected:
                 raise ValueError(f'there are {len(unexpected)} unexpected ids{hint_path}: {sorted(unexpected)}')
 
-        for category_name in df.columns:
-            if (df[category_name] < 0).any() or (df[category_name] > 1).any():
-                raise ValueError(f'error{hint_path} column "{category_name}" contains values out of range [0,1]')
+        for category_id in df.columns:
+            if (df[category_id] < 0).any() or (df[category_id] > 1).any():
+                raise ValueError(f'error{hint_path} column "{category_id}" contains values out of range [0,1]')
 
         prevs = df.values
         round_errors = np.abs(prevs.sum(axis=-1) - 1.) > constants.ERROR_TOL
@@ -180,13 +187,6 @@ class ResultSubmission:
 
         return df
 
-    def sort_categories(self):
-        self.df = self.df.reindex([self.df.columns[0]] + sorted(self.df.columns[1:]), axis=1)
-        self.categories = sorted(self.categories)
-
-    def filenames(self):
-        return self.df.filename.values
-
 
 def evaluate_submission(true_prevs: ResultSubmission, predicted_prevs: ResultSubmission, sample_size=None, average=True):
     if sample_size is None:
@@ -199,18 +199,19 @@ def evaluate_submission(true_prevs: ResultSubmission, predicted_prevs: ResultSub
     if len(true_prevs) != len(predicted_prevs):
         raise ValueError(f'size mismatch, ground truth file has {len(true_prevs)} entries '
                          f'while the file of predictions contains {len(predicted_prevs)} entries')
-    true_prevs.sort_categories()
-    predicted_prevs.sort_categories()
-    if true_prevs.categories != predicted_prevs.categories:
+    if true_prevs.n_categories != predicted_prevs.n_categories:
         raise ValueError(f'these result files are not comparable since the categories are different: '
-                         f'true={true_prevs.categories} vs. predictions={predicted_prevs.categories}')
+                         f'true={true_prevs.n_categories} categories vs. '
+                         f'predictions={predicted_prevs.n_categories} categories')
     ae, rae = [], []
-    for sample_name, true_prevalence in true_prevs.iterrows():
-        pred_prevalence = predicted_prevs.prevalence(sample_name)
+    for sample_id, true_prevalence in true_prevs.iterrows():
+        pred_prevalence = predicted_prevs.prevalence(sample_id)
         ae.append(qp.error.ae(true_prevalence, pred_prevalence))
         rae.append(qp.error.rae(true_prevalence, pred_prevalence, eps=1./(2*sample_size)))
+
     ae = np.asarray(ae)
     rae = np.asarray(rae)
+
     if average:
         return ae.mean(), rae.mean()
     else:
@@ -224,3 +225,4 @@ def evaluate_submission(true_prevs: ResultSubmission, predicted_prevs: ResultSub
 
 
 
+
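# --- illustration (not part of the patch) -----------------------------------
# A short sketch of the reworked ResultSubmission format: category columns are
# now the integers 0..n-1, inferred from the first prevalence vector added, and
# rows are indexed by the integer sample id (no category-name list is needed).
# Import paths and file names are assumptions for the example.
import numpy as np
import constants
from data import ResultSubmission, evaluate_submission

submission = ResultSubmission()
submission.add(0, np.asarray([0.7, 0.3]))     # the first add() fixes n_categories=2
submission.add(1, np.asarray([0.45, 0.55]))
print(submission.n_categories)                # -> 2
print(submission.prevalence(1))               # -> [0.45 0.55]
for sample_id, prev in submission.iterrows():
    print(sample_id, prev)

# scoring a complete predictions file against the ground truth (both files must
# contain the expected number of samples, see the check_* methods above)
true_prevs = ResultSubmission.load('dev_prevalences.csv')
pred_prevs = ResultSubmission.load('predictions.csv')
mae, mrae = evaluate_submission(true_prevs, pred_prevs, sample_size=constants.TXA_SAMPLE_SIZE)
# -----------------------------------------------------------------------------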
diff --git a/LeQua2022/evaluation.py b/LeQua2022/evaluation.py
index e56d6d5..b52c1cb 100644
--- a/LeQua2022/evaluation.py
+++ b/LeQua2022/evaluation.py
@@ -9,8 +9,10 @@ LeQua2022 Official evaluation script
 """
 
 def main(args):
-    if args.task in {'T1A'}:
-        qp.environ['SAMPLE_SIZE'] = constants.T1A_SAMPLE_SIZE
+    if args.task in {'T1A', 'T2A'}:
+        qp.environ['SAMPLE_SIZE'] = constants.TXA_SAMPLE_SIZE
+    if args.task in {'T1B', 'T2B'}:
+        qp.environ['SAMPLE_SIZE'] = constants.TXB_SAMPLE_SIZE
     true_prev = ResultSubmission.load(args.true_prevalences)
     pred_prev = ResultSubmission.load(args.pred_prevalences)
     mae, mrae = evaluate_submission(true_prev, pred_prev)
diff --git a/LeQua2022/predict.py b/LeQua2022/predict.py
index 31d1fad..d3e1a9b 100644
--- a/LeQua2022/predict.py
+++ b/LeQua2022/predict.py
@@ -22,13 +22,13 @@ def main(args):
               f'dev samples ({constants.DEV_SAMPLES}) nor with the expected number of '
               f'test samples ({constants.TEST_SAMPLES}).')
 
-    _, categories = load_category_map(args.catmap)
+    # _, categories = load_category_map(args.catmap)
 
     # load pickled model
     model = pickle.load(open(args.model, 'rb'))
 
     # predictions
-    predictions = ResultSubmission(categories=list(range(len(categories))))
+    predictions = ResultSubmission()
     for sampleid, sample in tqdm(gen_load_samples_T1(args.samples, args.nf),
                                    desc='predicting', total=nsamples):
         predictions.add(sampleid, model.quantify(sample))
@@ -48,8 +48,6 @@ if __name__=='__main__':
                         help='Path to the directory containing the samples')
     parser.add_argument('output', metavar='PREDICTIONS-PATH', type=str,
                         help='Path where to store the predictions file')
-    parser.add_argument('catmap', metavar='CATEGORY-MAP-PATH', type=str,
-                        help='Path to the category map file')
     parser.add_argument('nf', metavar='NUM-FEATURES', type=int,
                         help='Number of features seen during training')
     args = parser.parse_args()
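# --- illustration (not part of the patch) -----------------------------------
# With the catmap argument removed, predict.py now takes one positional
# argument fewer. Assuming the remaining positionals keep the order
# model, samples, output, nf (only the last three are visible in this hunk),
# an invocation would look roughly like:
#
#   python predict.py models/CC.pkl dev_vectors/ predictions.csv 300
#
# where the paths are placeholders and 300 stands for the number of features
# seen during training.
# -----------------------------------------------------------------------------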
diff --git a/quapy/evaluation.py b/quapy/evaluation.py
index 42ecf01..ff0b356 100644
--- a/quapy/evaluation.py
+++ b/quapy/evaluation.py
@@ -81,7 +81,7 @@ def natural_prevalence_prediction(
 
 def gen_prevalence_prediction(model: BaseQuantifier, gen_fn: Callable, eval_budget=None):
     if not inspect.isgenerator(gen_fn()):
-        raise ValueError('param "gen_fun" is not a generator')
+        raise ValueError('param "gen_fn" is not a callable returning a generator')
 
     if not isinstance(eval_budget, int):
         eval_budget = -1
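# --- illustration (not part of the patch) -----------------------------------
# gen_prevalence_prediction expects a *callable* that returns a fresh generator
# each time it is invoked (not the generator itself). A minimal sketch with
# made-up data of what such a callable can look like:
import numpy as np

def gen_fn():
    # yields (sample_instances, true_prevalence) pairs, one sample per iteration,
    # mirroring what gen_load_samples_T1(..., return_id=False) produces
    rng = np.random.RandomState(0)
    for _ in range(10):
        X = rng.rand(250, 300)          # a sample of 250 instances with 300 features
        prev = np.asarray([0.4, 0.6])   # its true class prevalence
        yield X, prev

# inspect.isgenerator(gen_fn()) is True, so the check above passes and
# gen_prevalence_prediction(model, gen_fn, eval_budget=...) can iterate over it.
# -----------------------------------------------------------------------------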
diff --git a/quapy/model_selection.py b/quapy/model_selection.py
index 95c6ff8..602c5e6 100644
--- a/quapy/model_selection.py
+++ b/quapy/model_selection.py
@@ -15,7 +15,7 @@ class GridSearchQ(BaseQuantifier):
     def __init__(self,
                  model: BaseQuantifier,
                  param_grid: dict,
-                 sample_size: int,
+                 sample_size: Union[int, None],
                  protocol='app',
                  n_prevpoints: int = None,
                  n_repetitions: int = 1,
@@ -32,30 +32,33 @@ class GridSearchQ(BaseQuantifier):
         protocol for quantification.
         :param model: the quantifier to optimize
         :param param_grid: a dictionary with keys the parameter names and values the list of values to explore for
-        :param sample_size: the size of the samples to extract from the validation set
-        that particular parameter
-        :param protocol: either 'app' for the artificial prevalence protocol, or 'npp' for the natural prevalence
-        protocol
-        :param n_prevpoints: if specified, indicates the number of equally distant point to extract from the interval
+        :param sample_size: the size of the samples to extract from the validation set (ignored if protocol='gen')
+        :param protocol: either 'app' for the artificial prevalence protocol, 'npp' for the natural prevalence
+        protocol, or 'gen' for using a custom sampling generator function
+        :param n_prevpoints: if specified, indicates the number of equally distant points to extract from the interval
         [0,1] in order to define the prevalences of the samples; e.g., if n_prevpoints=5, then the prevalences for
         each class will be explored in [0.00, 0.25, 0.50, 0.75, 1.00]. If not specified, then eval_budget is requested.
-        Ignored if protocol='npp'.
+        Ignored if protocol!='app'.
         :param n_repetitions: the number of repetitions for each combination of prevalences. This parameter is ignored
-        if eval_budget is set and is lower than the number of combinations that would be generated using the value
-        assigned to n_prevpoints (for the current number of classes and n_repetitions)
+        for protocol='app' if eval_budget is set and is lower than the number of combinations that would be
+        generated using the value assigned to n_prevpoints (for the current number of classes and n_repetitions).
+        Ignored for protocol='npp' and protocol='gen' (use eval_budget for setting a maximum number of samples in
+        those cases).
         :param eval_budget: if specified, sets a ceil on the number of evaluations to perform for each hyper-parameter
-        combination. For example, if there are 3 classes, n_repetitions=1 and eval_budget=20, then n_prevpoints will be
-        set to 5, since this will generate 15 different prevalences:
-         [0, 0, 1], [0, 0.25, 0.75], [0, 0.5, 0.5] ... [1, 0, 0]
-        Ignored if protocol='npp'.
+        combination. For example, if protocol='app', there are 3 classes, n_repetitions=1 and eval_budget=20, then
+        n_prevpoints will be set to 5, since this will generate 15 different prevalences, i.e., [0, 0, 1],
+        [0, 0.25, 0.75], [0, 0.5, 0.5] ... [1, 0, 0], and since setting it to 6 would generate more than
+        20. When protocol='gen', this indicates the maximum number of samples to generate; fewer samples will be
+        generated if the generator yields fewer.
         :param error: an error function (callable) or a string indicating the name of an error function (valid ones
         are those in qp.error.QUANTIFICATION_ERROR)
         :param refit: whether or not to refit the model on the whole labelled collection (training+validation) with
-        the best chosen hyperparameter combination
+        the best chosen hyperparameter combination. Ignored if protocol='gen'
         :param val_split: either a LabelledCollection on which to test the performance of the different settings, or
-        a float in [0,1] indicating the proportion of labelled data to extract from the training set
+        a float in [0,1] indicating the proportion of labelled data to extract from the training set, or a callable
+        that returns a generator of (sample, prevalence) pairs each time it is invoked (only for protocol='gen').
         :param n_jobs: number of parallel jobs
-        :param random_seed: set the seed of the random generator to replicate experiments
+        :param random_seed: set the seed of the random generator to replicate experiments. Ignored if protocol='gen'.
         :param timeout: establishes a timer (in seconds) for each of the hyperparameters configurations being tested.
         Whenever a run takes longer than this timer, that configuration will be ignored. If all configurations end up
         being ignored, a TimeoutError exception is raised. If -1 (default) then no time bound is set.
@@ -79,17 +82,13 @@ class GridSearchQ(BaseQuantifier):
             'unknown protocol: valid ones are "app" or "npp" for the "artificial" or the "natural" prevalence ' \
             'protocols. Use protocol="gen" when passing a generator function through val_split that yields a ' \
             'sample (instances) and their prevalence (ndarray) at each iteration.'
-        if self.protocol == 'npp':
-            if self.n_repetitions is None or self.n_repetitions == 1:
-                if self.eval_budget is not None:
-                    print(f'[warning] when protocol=="npp" the parameter n_repetitions should be indicated '
-                          f'(and not eval_budget). Setting n_repetitions={self.eval_budget}...')
-                    self.n_repetitions = self.eval_budget
-                else:
-                    raise ValueError(f'when protocol=="npp" the parameter n_repetitions should be indicated '
-                                     f'(and should be >1).')
-            if self.n_prevpoints is not None:
-                print('[warning] n_prevpoints has been set along with the npp protocol, and will be ignored')
+        assert self.eval_budget is None or isinstance(self.eval_budget, int)
+        if self.protocol in ['npp', 'gen']:
+            if self.protocol=='npp' and (self.eval_budget is None or self.eval_budget <= 0):
+                raise ValueError(f'when protocol="npp" the parameter eval_budget should be '
+                                 f'indicated (and should be >0).')
+            if self.n_prevpoints != 1:
+                print('[warning] n_prevpoints has been set and will be ignored for the selected protocol')
 
     def sout(self, msg):
         if self.verbose:
@@ -145,7 +144,7 @@ class GridSearchQ(BaseQuantifier):
         else:
             raise ValueError('unknown protocol')
 
-    def fit(self, training: LabelledCollection, val_split: Union[LabelledCollection, float] = None):
+    def fit(self, training: LabelledCollection, val_split: Union[LabelledCollection, float, Callable] = None):
         """
         :param training: the training set on which to optimize the hyperparameters
         :param val_split: either a LabelledCollection on which to test the performance of the different settings, or
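# --- illustration (not part of the patch) -----------------------------------
# A minimal usage sketch of GridSearchQ with the artificial prevalence protocol
# ('app'), complementing the generator-based ('gen') usage sketched earlier.
# The dataset, base quantifier and grid values are illustrative assumptions.
import numpy as np
import quapy as qp
from quapy.method.aggregative import PACC
from quapy.model_selection import GridSearchQ
from sklearn.linear_model import LogisticRegression

training = qp.datasets.fetch_reviews('kindle', tfidf=True).training  # any LabelledCollection works

qsearch = GridSearchQ(
    model=PACC(LogisticRegression()),
    param_grid={'C': np.logspace(-3, 3, 7), 'class_weight': ['balanced', None]},
    sample_size=250,
    protocol='app',
    eval_budget=100,     # caps evaluations per configuration; n_prevpoints is derived from it
    refit=True
)
qsearch.fit(training, val_split=0.4)   # hold out 40% of the training data for validation
print(qsearch.best_params_, qsearch.best_score_)
# -----------------------------------------------------------------------------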