From aeb0fcf84b9d567f4933563306365bff5659b443 Mon Sep 17 00:00:00 2001 From: Alex Moreo Date: Fri, 27 Aug 2021 13:57:26 +0200 Subject: [PATCH] adding tables generation --- MultiLabel/main.py | 18 ++++++++++++------ MultiLabel/mlevaluation.py | 6 +++--- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/MultiLabel/main.py b/MultiLabel/main.py index 583b5a7..afb03e7 100644 --- a/MultiLabel/main.py +++ b/MultiLabel/main.py @@ -41,8 +41,8 @@ def models(): # yield 'NaivePCC', MultilabelNaiveAggregativeQuantifier(PCC(cls())) # yield 'NaiveACC', MultilabelNaiveAggregativeQuantifier(ACC(cls())) # yield 'NaivePACC', MultilabelNaiveAggregativeQuantifier(PACC(cls())) - # yield 'HDy', MultilabelNaiveAggregativeQuantifier(HDy(cls())) - # yield 'EMQ', MultilabelQuantifier(EMQ(calibratedCls())) + # yield 'NaiveHDy', MultilabelNaiveAggregativeQuantifier(HDy(cls())) + # yield 'NaiveSLD', MultilabelNaiveAggregativeQuantifier(EMQ(calibratedCls())) # yield 'StackCC', MLCC(MultilabelStackedClassifier(cls())) # yield 'StackPCC', MLPCC(MultilabelStackedClassifier(cls())) # yield 'StackACC', MLACC(MultilabelStackedClassifier(cls())) @@ -159,10 +159,14 @@ def load_results(result_path): estim_prevs = [np.vstack([estim_i, 1 - estim_i]).T for estim_i in estim_prevs] # add the constrained prevalence return true_prevs, estim_prevs results = pickle.load(open(result_path, 'rb')) - results_npp = _unpack_result_lot(results['npp']) - results_app = _unpack_result_lot(results['app']) - return results_npp, results_app - + results = { + 'npp': _unpack_result_lot(results['npp']), + 'app': _unpack_result_lot(results['app']), + } + return results + # results_npp = _unpack_result_lot(results['npp']) + # results_app = _unpack_result_lot(results['app']) + # return results_npp, results_app def run_experiment(dataset_name, model_name, model): @@ -197,3 +201,5 @@ if __name__ == '__main__': + + diff --git a/MultiLabel/mlevaluation.py b/MultiLabel/mlevaluation.py index e03b11e..71a5f33 100644 --- a/MultiLabel/mlevaluation.py +++ b/MultiLabel/mlevaluation.py @@ -8,7 +8,7 @@ import itertools from tqdm import tqdm -def __check_error(error_metric): +def check_error_str(error_metric): if isinstance(error_metric, str): error_metric = qp.error.from_name(error_metric) @@ -49,7 +49,7 @@ def ml_natural_prevalence_evaluation(model, error_metric:Union[str,Callable]='mae', random_seed=42): - error_metric = __check_error(error_metric) + error_metric = check_error_str(error_metric) true_prevs, estim_prevs = ml_natural_prevalence_prediction(model, test, sample_size, repeats, random_seed) @@ -88,7 +88,7 @@ def ml_artificial_prevalence_evaluation(model, error_metric:Union[str,Callable]='mae', random_seed=42): - error_metric = __check_error(error_metric) + error_metric = check_error_str(error_metric) true_prevs, estim_prevs = ml_artificial_prevalence_prediction(model, test, sample_size, n_prevalences, repeats, random_seed)