main updated
This commit is contained in:
parent
da827943d6
commit
b969234244
|
@ -2,7 +2,7 @@ import pandas as pd
|
|||
import quapy as qp
|
||||
from quapy.method.aggregative import SLD
|
||||
from quapy.protocol import APP
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.svm import SVC
|
||||
|
||||
import quacc.evaluation as eval
|
||||
from quacc.estimator import AccuracyEstimator
|
||||
|
@ -17,23 +17,24 @@ pd.set_option("display.float_format", "{:.4f}".format)
|
|||
def test_2(dataset_name):
|
||||
train, test = get_dataset(dataset_name)
|
||||
|
||||
model = LogisticRegression()
|
||||
model = SVC(probability=True)
|
||||
|
||||
print(f"fitting model {model.__class__.__name__}...", end=" ")
|
||||
print(f"fitting model {model.__class__.__name__}...", end=" ", flush=True)
|
||||
model.fit(*train.Xy)
|
||||
print("fit")
|
||||
|
||||
qmodel = SLD(LogisticRegression())
|
||||
qmodel = SLD(SVC(probability=True))
|
||||
estimator = AccuracyEstimator(model, qmodel)
|
||||
|
||||
print(f"fitting qmodel {qmodel.__class__.__name__}...", end=" ")
|
||||
print(f"fitting qmodel {qmodel.__class__.__name__}...", end=" ", flush=True)
|
||||
estimator.fit(train)
|
||||
print("fit")
|
||||
|
||||
n_prevalences = 21
|
||||
repreats = 1000
|
||||
protocol = APP(test, n_prevalences=n_prevalences, repeats=repreats)
|
||||
print( f"Tests:\n\
|
||||
print(
|
||||
f"Tests:\n\
|
||||
protocol={protocol.__class__.__name__}\n\
|
||||
n_prevalences={n_prevalences}\n\
|
||||
repreats={repreats}\n\
|
||||
|
@ -49,9 +50,9 @@ def test_2(dataset_name):
|
|||
|
||||
def main():
|
||||
for dataset_name in [
|
||||
"hp",
|
||||
"imdb",
|
||||
"spambase",
|
||||
# "hp",
|
||||
# "spambase",
|
||||
]:
|
||||
print(dataset_name)
|
||||
test_2(dataset_name)
|
||||
|
|
Loading…
Reference in New Issue