some plots
This commit is contained in:
parent
ede214aa54
commit
faba2494b2
|
@ -1,3 +1,5 @@
|
||||||
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.linear_model import LogisticRegression
|
from sklearn.linear_model import LogisticRegression
|
||||||
from os.path import join
|
from os.path import join
|
||||||
|
@ -5,7 +7,7 @@ import quapy as qp
|
||||||
from quapy.protocol import UPP
|
from quapy.protocol import UPP
|
||||||
from quapy.method.aggregative import KDEyML
|
from quapy.method.aggregative import KDEyML
|
||||||
|
|
||||||
DEBUG = True
|
DEBUG = False
|
||||||
|
|
||||||
qp.environ["SAMPLE_SIZE"] = 100 if DEBUG else 500
|
qp.environ["SAMPLE_SIZE"] = 100 if DEBUG else 500
|
||||||
val_repeats = 100 if DEBUG else 500
|
val_repeats = 100 if DEBUG else 500
|
||||||
|
@ -21,7 +23,7 @@ if DEBUG:
|
||||||
bandwidth_range = np.linspace(0.01, 0.20, 10)
|
bandwidth_range = np.linspace(0.01, 0.20, 10)
|
||||||
|
|
||||||
def datasets():
|
def datasets():
|
||||||
for dataset_name in qp.datasets.UCI_MULTICLASS_DATASETS[:4]:
|
for dataset_name in qp.datasets.UCI_MULTICLASS_DATASETS:
|
||||||
dataset = qp.datasets.fetch_UCIMulticlassDataset(dataset_name)
|
dataset = qp.datasets.fetch_UCIMulticlassDataset(dataset_name)
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
dataset = dataset.reduce(random_state=0)
|
dataset = dataset.reduce(random_state=0)
|
||||||
|
@ -40,7 +42,8 @@ def experiment_dataset(dataset):
|
||||||
param_grid={'bandwidth': bandwidth_range},
|
param_grid={'bandwidth': bandwidth_range},
|
||||||
protocol=UPP(train_va, repeats=val_repeats),
|
protocol=UPP(train_va, repeats=val_repeats),
|
||||||
refit=False,
|
refit=False,
|
||||||
n_jobs=-1
|
n_jobs=-1,
|
||||||
|
verbose=True
|
||||||
).fit(train_tr)
|
).fit(train_tr)
|
||||||
chosen_bandwidth = modsel.best_params_['bandwidth']
|
chosen_bandwidth = modsel.best_params_['bandwidth']
|
||||||
modsel_choice = float(chosen_bandwidth)
|
modsel_choice = float(chosen_bandwidth)
|
||||||
|
@ -83,7 +86,10 @@ def plot_bandwidth(val_choice, test_results):
|
||||||
|
|
||||||
# Mostrar la gráfica
|
# Mostrar la gráfica
|
||||||
plt.grid(True)
|
plt.grid(True)
|
||||||
plt.show()
|
# plt.show()
|
||||||
|
os.makedirs('./plots', exist_ok=True)
|
||||||
|
plt.savefig(f'./plots/{dataset_name}.png')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
for dataset in datasets():
|
for dataset in datasets():
|
||||||
|
|
Loading…
Reference in New Issue