From 6c5bd674eadee1041726e2b5a8058dfa1137bbda Mon Sep 17 00:00:00 2001
From: Alejandro Moreo <alejandro.moreo@isti.cnr.it>
Date: Wed, 14 Feb 2024 18:54:07 +0100
Subject: [PATCH] cleaning

---
 LeQua2024/_lequa2024.py |  9 ++++-----
 LeQua2024/baselines.py  | 23 -----------------------
 2 files changed, 4 insertions(+), 28 deletions(-)

diff --git a/LeQua2024/_lequa2024.py b/LeQua2024/_lequa2024.py
index 5e414c5..285549d 100644
--- a/LeQua2024/_lequa2024.py
+++ b/LeQua2024/_lequa2024.py
@@ -1,6 +1,4 @@
-from typing import Tuple, Union
 import pandas as pd
-import numpy as np
 import os
 from os.path import join
 
@@ -63,9 +61,10 @@ def fetch_lequa2024(task, data_home='./data', merge_T3=False):
     val_true_prev_path = join(lequa_dir, task, 'public', 'dev_prevalences.txt')
     val_gen = SamplesFromDir(val_samples_path, val_true_prev_path, load_fn=load_fn)
 
-    test_samples_path = join(lequa_dir, task, 'public', 'test_samples')
-    test_true_prev_path = join(lequa_dir, task, 'public', 'test_prevalences.txt')
-    test_gen = SamplesFromDir(test_samples_path, test_true_prev_path, load_fn=load_fn)
+    # test_samples_path = join(lequa_dir, task, 'public', 'test_samples')
+    # test_true_prev_path = join(lequa_dir, task, 'public', 'test_prevalences.txt')
+    # test_gen = SamplesFromDir(test_samples_path, test_true_prev_path, load_fn=load_fn)
+    test_gen = None
 
     if task != 'T3':
         tr_path = join(lequa_dir, task, 'public', 'training_data.txt')
diff --git a/LeQua2024/baselines.py b/LeQua2024/baselines.py
index 28a19f0..ff4c33c 100644
--- a/LeQua2024/baselines.py
+++ b/LeQua2024/baselines.py
@@ -82,25 +82,10 @@ def main(args):
         else:
             quantifier.fit(train)
 
-
-        # valid_error = quantifier.best_score_
-
-        # test_err = qp.evaluation.evaluate(quantifier, protocol=gen_test, error_metric='mrae', verbose=True)
-        # print(f'method={q_name} got MRAE={test_err:.4f}')
-        #
-        # results.append((q_name, valid_error, test_err))
-
-
         print(f'saving model in {model_path}')
         pickle.dump(quantifier, open(model_path, 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
 
 
-    # print('\nResults')
-    # print('Method\tValid-err\ttest-err')
-    # for q_name, valid_error, test_err in results:
-    #     print(f'{q_name}\t{valid_error:.4}\t{test_err:.4f}')
-
-
 if __name__ == '__main__':
 
     parser = argparse.ArgumentParser(description='LeQua2024 baselines')
@@ -110,12 +95,4 @@ if __name__ == '__main__':
                         help='Path of the directory containing LeQua 2024 data', default='./data')
     args = parser.parse_args()
 
-    # def assert_file(filename):
-    #     if not os.path.exists(os.path.join(args.datadir, filename)):
-    #         raise FileNotFoundError(f'path {args.datadir} does not contain "{filename}"')
-    #
-    # assert_file('dev_prevalences.txt')
-    # assert_file('training_data.txt')
-    # assert_file('dev_samples')
-
     main(args)