diff --git a/quapy/benchmarking/_base.py b/quapy/benchmarking/_base.py
index 227efaf..b7e4f88 100644
--- a/quapy/benchmarking/_base.py
+++ b/quapy/benchmarking/_base.py
@@ -49,6 +49,9 @@ class Benchmark(ABC):
         makedirs(join(home_dir, 'tables'))
         makedirs(join(home_dir, 'plots'))
 
+        # caches the training prevalence of each dataset, keyed by dataset name (see gen_plots)
+        self.train_prevalence = {}
+
     def _run_id(self, method: MethodDescriptor, dataset: str):
         sep = Benchmark.ID_SEPARATOR
         assert sep not in method.id, \
@@ -104,10 +107,27 @@ class Benchmark(ABC):
 
         Table.LatexPDF(join(self.home_dir, 'tables', 'results.pdf'), list(tables.values()))
 
-    def gen_plots(self):
-        pass
+    def gen_plots(self, results, metrics=None):
+        import matplotlib.pyplot as plt
+        plt.rcParams.update({'font.size': 11})
 
-    def show_report(self, method, dataset, report: pd.DataFrame):
+        if metrics is None:
+            metrics = ['ae']
+
+        # generates one error-by-drift plot per metric, aggregating all (method, dataset) results
+        for metric in metrics:
+            method_names, true_prevs, estim_prevs, train_prevs = [], [], [], []
+            for (method, dataset, result) in results:
+                method_names.append(method.name)
+                true_prevs.append(np.vstack(result['true-prev'].values))
+                estim_prevs.append(np.vstack(result['estim-prev'].values))
+                train_prevs.append(self.get_training_prevalence(dataset))
+            # plot once per metric, after all results have been collected
+            path = join(self.home_dir, 'plots', f'err_by_drift_{metric}.pdf')
+            qp.plot.error_by_drift(method_names, true_prevs, estim_prevs, train_prevs,
+                                   error_name=metric, n_bins=20, savepath=path)
+
+    def _show_report(self, method, dataset, report: pd.DataFrame):
         id = method.id
         MAE = report['mae'].mean()
         mae_std = report['mae'].std()
@@ -146,19 +166,20 @@ class Benchmark(ABC):
             seed=0,
             asarray=False
         )
-        results += [(method, dataset, result) for (method, dataset), result in zip(pending_job_args, remaining_results)]
+        results += [
+            (method, dataset, result) for (method, dataset), result in zip(pending_job_args, remaining_results)
+        ]
 
         # print results
         for method, dataset, result in results:
-            self.show_report(method, dataset, result)
+            self._show_report(method, dataset, result)
 
         self.gen_tables(results)
-        self.gen_plots()
-
-        # def gen_plots(self, methods=None):
-        #     if methods is None:
-
+        self.gen_plots(results)
 
+    @abstractmethod
+    def get_training_prevalence(self, dataset: str):
+        ...
 
     def __add__(self, other: 'Benchmark'):
         return CombinedBenchmark(self, other, self.n_jobs)
@@ -192,6 +213,12 @@ class TypicalBenchmark(Benchmark):
     def get_sample_size(self)-> int:
         ...
 
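+    # used by get_training_prevalence() to recover the training prevalence of a dataset
+    # for which run_method() has not been executed (the result is then cached)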
+    @abstractmethod
+    def get_training(self, dataset: str) -> LabelledCollection:
+        ...
+
     @abstractmethod
     def get_trModsel_valprotModsel_trEval_teprotEval(self, dataset:str)->\
             (LabelledCollection, AbstractProtocol, LabelledCollection, AbstractProtocol):
@@ -212,7 +239,9 @@ class TypicalBenchmark(Benchmark):
 
         with qp.util.temp_seed(random_state):
             # data split
-            trModSel, valprotModSel, trEval, teprotEval = self.get_trModsel_valprotModsel_trEval_teprotEval(dataset)
+            trModSel, valprotModSel, trEval, teprotEval = self.get_trModsel_valprotModsel_trEval_teprotEval(dataset)
+            # cache the training prevalence, so that gen_plots need not reload the dataset
+            self.train_prevalence[dataset] = trEval.prevalence()
 
             # model selection
             modsel = GridSearchQ(
@@ -247,6 +276,12 @@ class TypicalBenchmark(Benchmark):
 
         return report
 
+    def get_training_prevalence(self, dataset: str):
+        if dataset not in self.train_prevalence:
+            training = self.get_training(dataset)
+            self.train_prevalence[dataset] = training.prevalence()
+        return self.train_prevalence[dataset]
+
 
 class UCIBinaryBenchmark(TypicalBenchmark):
 
@@ -259,6 +294,9 @@ class UCIBinaryBenchmark(TypicalBenchmark):
         testprotModsel = APP(teEval, n_prevalences=21, repeats=100)
         return trModsel, valprotModsel, trEval, testprotModsel
 
+    def get_training(self, dataset: str) -> LabelledCollection:
+        return qp.datasets.fetch_UCIBinaryDataset(dataset).training
+
     def get_sample_size(self) -> int:
         return 100
 
@@ -284,6 +322,9 @@ class UCIMultiBenchmark(TypicalBenchmark):
         testprotModsel = UPP(teEval, repeats=1000)
         return trModsel, valprotModsel, trEval, testprotModsel
 
+    def get_training(self, dataset: str) -> LabelledCollection:
+        return qp.datasets.fetch_UCIMulticlassDataset(dataset).training
+
     def get_sample_size(self) -> int:
         return 500
 
@@ -291,6 +332,7 @@ class UCIMultiBenchmark(TypicalBenchmark):
         return 'mae'
 
 
+
 if __name__ == '__main__':
 
     from quapy.benchmarking.typical import *
diff --git a/quapy/plot.py b/quapy/plot.py
index cdc3bd5..752ec4a 100644
--- a/quapy/plot.py
+++ b/quapy/plot.py
@@ -259,7 +259,8 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
     data = _join_data_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, x_error, y_error, method_order)
 
     if method_order is None:
-        method_order = method_names
+        # np.unique removes names repeated across datasets; note that it also sorts them
+        method_order = np.unique(method_names)
 
     _set_colors(ax, n_methods=len(method_order))
 
@@ -329,7 +330,6 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
 
     if show_legend:
-        fig.legend(loc='lower center',
-                   bbox_to_anchor=(1, 0.5),
-                   ncol=(len(method_names)+1)//2)
-
+        # fig.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05), ncol=(len(method_order)//2)+1)
+        fig.legend(loc='upper right', bbox_to_anchor=(1, 0.6))
+
     _save_or_show(savepath)
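
Below, a minimal sketch of the inputs that the new gen_plots assembles for qp.plot.error_by_drift, and of why plot.py now dedupes method names; the method names, sample counts, binary classes, and Dirichlet-drawn prevalence vectors are made up for illustration:

    import numpy as np

    # one entry per (method, dataset) pair, so names repeat across datasets
    method_names = ['CC', 'CC', 'PACC', 'PACC']

    rng = np.random.default_rng(0)

    def random_prevalences(n_samples, n_classes=2):
        # each row is a prevalence vector: non-negative values summing to 1
        return rng.dirichlet(np.ones(n_classes), size=n_samples)

    true_prevs = [random_prevalences(50) for _ in method_names]
    estim_prevs = [random_prevalences(50) for _ in method_names]
    train_prevs = [np.array([0.7, 0.3]) for _ in method_names]

    # plot.py now dedupes the repeated names; np.unique also sorts them
    assert list(np.unique(method_names)) == ['CC', 'PACC']

    # requires quapy:
    # qp.plot.error_by_drift(method_names, true_prevs, estim_prevs, train_prevs,
    #                        error_name='ae', n_bins=20, savepath='err_by_drift_ae.pdf')

Since np.unique sorts, the legend no longer follows insertion order; callers that need a fixed ordering can still pass method_order explicitly.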