import os
import pickle
import shutil

import numpy as np
from sklearn.linear_model import LogisticRegression
from os.path import join

import quapy as qp
from quapy.protocol import UPP
from kdey_devel import KDEyML
from utils import measuretime


# When True, the script runs a reduced configuration: fewer datasets,
# reduced datasets, a coarser bandwidth grid and a separate output folder.
DEBUG = True

qp.environ["SAMPLE_SIZE"] = 100 if DEBUG else 500
val_repeats = 100 if DEBUG else 500
test_repeats = 100 if DEBUG else 500
if DEBUG:
    qp.environ["DEFAULT_CLS"] = LogisticRegression()

test_results = {}
val_choice = {}

# Grid of candidate bandwidths explored both in model selection and in test.
bandwidth_range = np.linspace(0.01, 0.20, 20)
if DEBUG:
    bandwidth_range = np.linspace(0.01, 0.20, 5)


def datasets():
    """Yield the UCI multiclass datasets used in the experiment.

    In DEBUG mode only the first four dataset names are fetched, and each
    dataset is passed through ``reduce(random_state=0)`` before being yielded.
    """
    dataset_list = qp.datasets.UCI_MULTICLASS_DATASETS
    if DEBUG:
        dataset_list = dataset_list[:4]
    for dataset_name in dataset_list:
        dataset = qp.datasets.fetch_UCIMulticlassDataset(dataset_name)
        if DEBUG:
            dataset = dataset.reduce(random_state=0)
        yield dataset


@measuretime
def predict_b_modsel(dataset):
    """Choose the bandwidth via model selection on a validation split.

    The training set is split (stratified, ``random_state=0``) into a
    training and a validation part; a grid search over ``bandwidth_range``
    is fitted on the former and evaluated with a UPP protocol on the latter.

    Returns the selected bandwidth as a plain ``float``.

    NOTE(review): ``measuretime`` comes from ``utils`` and is not visible
    here; the caller unpacks the decorated call as ``(choice, elapsed)``,
    so the decorator presumably appends the elapsed time -- confirm.
    """
    train = dataset.training
    train_tr, train_va = train.split_stratified(random_state=0)
    kdey = KDEyML(random_state=0)
    modsel = qp.model_selection.GridSearchQ(
        model=kdey,
        param_grid={'bandwidth': bandwidth_range},
        protocol=UPP(train_va, repeats=val_repeats),
        refit=False,
        n_jobs=-1,
        verbose=True
    ).fit(train_tr)
    chosen_bandwidth = modsel.best_params_['bandwidth']
    # Cast to a builtin float so the value is independent of numpy scalars.
    modsel_choice = float(chosen_bandwidth)
    return modsel_choice


def in_test_search(dataset, n_jobs=-1):
    """Evaluate every bandwidth of the grid directly on the test set.

    For each value in ``bandwidth_range`` a KDEyML quantifier is fitted on
    the training set and its MAE is measured with a UPP protocol over the
    test set (``test_repeats`` repetitions).

    Returns a pair ``(dataset_results, bandwidth_range)`` where
    ``dataset_results`` is whatever ``qp.util.parallel`` returns for the
    per-bandwidth MAE values.
    """
    train, test = dataset.train_test
    print(f"testing KDEy in {dataset.name}")

    def experiment_job(bandwidth):
        kdey = KDEyML(bandwidth=bandwidth, random_state=0)
        kdey.fit(train)
        test_gen = UPP(test, repeats=test_repeats)
        mae = qp.evaluation.evaluate(kdey, protocol=test_gen, error_metric='mae', verbose=True)
        print(f'{bandwidth=}: {mae:.5f}')
        return float(mae)

    dataset_results = qp.util.parallel(experiment_job, bandwidth_range, n_jobs=n_jobs)
    return dataset_results, bandwidth_range


def plot_bandwidth(dataset_name, test_results, bandwidths, triplet_list_results):
    """Plot test MAE against bandwidth and mark each method's choice.

    ``triplet_list_results`` is an iterable of
    ``(method_name, method_choice, method_time)`` triplets; each choice is
    drawn as a dashed vertical line. The figure is saved as
    ``<plotdir>/<dataset_name>.png`` (``./plots_debug`` in DEBUG mode,
    ``./plots`` otherwise).
    """
    # Imported lazily so matplotlib is only required when plotting.
    import matplotlib.pyplot as plt

    print("PLOT", dataset_name)
    print(dataset_name)

    plt.figure(figsize=(8, 6))

    # show test results
    plt.plot(bandwidths, test_results, marker='o')

    for method_name, method_choice, _method_time in triplet_list_results:
        plt.axvline(x=method_choice, linestyle='--', label=method_name)

    # Add axis labels and title
    plt.xlabel('Bandwidth')
    plt.ylabel('MAE')
    plt.title(dataset_name)

    # Place the legend outside the axes, to the right
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    # Show the grid
    plt.grid(True)

    plotdir = './plots_debug' if DEBUG else './plots'
    os.makedirs(plotdir, exist_ok=True)
    plt.tight_layout()
    plt.savefig(f'{plotdir}/{dataset_name}.png')
    plt.close()


def error_table(dataset_name, test_results, bandwidth_range, triplet_list_results):
    """Print, per method, the chosen bandwidth, its score gap and its time.

    The "AE" column is the absolute difference between the test MAE obtained
    at the method's chosen bandwidth and the best (minimum) test MAE found on
    the grid. A choice that does not lie on the grid is assigned a score of 1.

    ``test_results[i]`` must be the test MAE for ``bandwidth_range[i]``.
    ``dataset_name`` is accepted for interface symmetry and is not used.
    """
    grid = np.asarray(bandwidth_range)
    # Compare scores with scores: the reference is the best MAE on the grid,
    # not the bandwidth at which it is attained.
    best_score = np.min(test_results)

    print('Method\tChoice\tAE\tTime')
    for method_name, method_choice, took in triplet_list_results:
        # Tolerant matching: the choice may have gone through float()/pickle.
        matches = np.flatnonzero(np.isclose(grid, method_choice))
        if matches.size:
            method_score = test_results[matches[0]]
        else:
            method_score = 1
        error = np.abs(best_score - method_score)
        print(f'{method_name}\t{method_choice}\t{error}\t{took:.3}s')


for dataset in datasets():
    print('NAME', dataset.name)
    print(len(dataset.training))
    print(len(dataset.test))

    result_path = f'./results/{dataset.name}/'
    if DEBUG:
        result_path = result_path.replace('results', 'results_debug')
        # Debug runs always start from scratch: drop any cached results.
        if os.path.exists(result_path):
            shutil.rmtree(result_path)

    dataset_results, bandwidth_range = qp.util.pickled_resource(
        join(result_path, 'test.pkl'), in_test_search, dataset
    )

    triplet_list_results = []

    modsel_choice, modsel_time = qp.util.pickled_resource(
        join(result_path, 'modsel.pkl'), predict_b_modsel, dataset
    )
    triplet_list_results.append(('modsel', modsel_choice, modsel_time,))

    print(f'Dataset = {dataset.name}')
    print(modsel_choice)
    print(dataset_results)

    plot_bandwidth(dataset.name, dataset_results, bandwidth_range, triplet_list_results)
    error_table(dataset.name, dataset_results, bandwidth_range, triplet_list_results)
    # TODO: a time_table(...) report was planned here; no such function is
    # defined in this file.