From 93dd6cb1c15eda8ef3ee6701364c11fc5bfcb5ec Mon Sep 17 00:00:00 2001
From: Lorenzo Volpi <lorenzo.volpi@outlook.com>
Date: Mon, 29 Apr 2024 17:35:43 +0200
Subject: [PATCH] training times added to globar report

---
 examples/ucimulti_experiments.py | 25 ++++++++++++++++++-------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/examples/ucimulti_experiments.py b/examples/ucimulti_experiments.py
index aae8c88..5193376 100644
--- a/examples/ucimulti_experiments.py
+++ b/examples/ucimulti_experiments.py
@@ -1,5 +1,7 @@
 import pickle
 import os
+from time import time
+from collections import defaultdict
 
 import numpy as np
 from sklearn.linear_model import LogisticRegression
@@ -38,9 +40,17 @@ def show_results(result_path):
     df = pd.read_csv(result_path+'.csv', sep='\t')
     pd.set_option('display.max_columns', None)
     pd.set_option('display.max_rows', None)
-    pv = df.pivot_table(index='Dataset', columns="Method", values=["MAE", "MRAE"], margins=True)
+    pv = df.pivot_table(index='Dataset', columns="Method", values=["MAE", "MRAE", "t_train"], margins=True)
     print(pv)
 
+def load_timings(result_path):
+    import pandas as pd
+    timings = defaultdict(lambda: {})
+    if not Path(result_path + '.csv').exists():
+        return timings
+
+    df = pd.read_csv(result_path+'.csv', sep='\t')
+    return timings | df.pivot_table(index='Dataset', columns='Method', values='t_train').to_dict()
 
 if __name__ == '__main__':
 
@@ -53,8 +63,9 @@ if __name__ == '__main__':
     os.makedirs(result_dir, exist_ok=True)
 
     global_result_path = f'{result_dir}/allmethods'
+    timings = load_timings(global_result_path)
     with open(global_result_path + '.csv', 'wt') as csv:
-        csv.write(f'Method\tDataset\tMAE\tMRAE\n')
+        csv.write(f'Method\tDataset\tMAE\tMRAE\tt_train\n')
 
     for method_name, quantifier, param_grid in METHODS:
 
@@ -64,9 +75,6 @@ if __name__ == '__main__':
 
             for dataset in qp.datasets.UCI_MULTICLASS_DATASETS:
 
-                if dataset in []:
-                    continue
-
                 print('init', dataset)
 
                 local_result_path = os.path.join(Path(global_result_path).parent, method_name + '_' + dataset + '.dataframe')
@@ -88,7 +96,8 @@ if __name__ == '__main__':
                         modsel = GridSearchQ(
                             quantifier, param_grid, protocol, refit=True, n_jobs=-1, verbose=1, error='mae'
                         )
-
+                        
+                        t_init = time()
                         try:
                             modsel.fit(train)
 
@@ -99,6 +108,8 @@ if __name__ == '__main__':
                         except:
                             print('something went wrong... trying to fit the default model')
                             quantifier.fit(train)
+                        timings[method_name][dataset] = time() - t_init
+                        
 
                         protocol = UPP(test, repeats=n_bags_test)
                         report = qp.evaluation.evaluation_report(
@@ -107,7 +118,7 @@ if __name__ == '__main__':
                         report.to_csv(local_result_path)
 
                 means = report.mean(numeric_only=True)
-                csv.write(f'{method_name}\t{dataset}\t{means["mae"]:.5f}\t{means["mrae"]:.5f}\n')
+                csv.write(f'{method_name}\t{dataset}\t{means["mae"]:.5f}\t{means["mrae"]:.5f}\t{timings[method_name][dataset]:.3f}\n')
                 csv.flush()
 
     show_results(global_result_path)
\ No newline at end of file