getpath moved and renamed, prevs_from_prot fixed
This commit is contained in:
parent
dcbbaba361
commit
ddce8634ac
|
@ -1,15 +1,11 @@
|
||||||
import os
|
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from quapy.data.base import LabelledCollection
|
from quapy.data.base import LabelledCollection
|
||||||
from sklearn.base import BaseEstimator
|
from sklearn.base import BaseEstimator
|
||||||
from sklearn.metrics import confusion_matrix
|
from sklearn.metrics import confusion_matrix
|
||||||
|
|
||||||
|
|
||||||
def getpath(basedir, cls_name, acc_name, dataset_name, method_name):
|
|
||||||
return f"results/{basedir}/{cls_name}/{acc_name}/{dataset_name}/{method_name}.json"
|
|
||||||
|
|
||||||
|
|
||||||
def fit_method(method, V):
|
def fit_method(method, V):
|
||||||
tinit = time()
|
tinit = time()
|
||||||
method.fit(V)
|
method.fit(V)
|
||||||
|
@ -45,7 +41,13 @@ def predictionsCAPcont_table(method, test_prot, gen_acc_measure, oracle=False):
|
||||||
|
|
||||||
|
|
||||||
def prevs_from_prot(prot):
|
def prevs_from_prot(prot):
|
||||||
return [Ui.prevalence() for Ui in prot()]
|
def _get_plain_prev(prev: np.ndarray):
|
||||||
|
if prev.shape[0] > 2:
|
||||||
|
return tuple(prev[1:])
|
||||||
|
else:
|
||||||
|
return prev[-1]
|
||||||
|
|
||||||
|
return [_get_plain_prev(Ui.prevalence()) for Ui in prot()]
|
||||||
|
|
||||||
|
|
||||||
def true_acc(h: BaseEstimator, acc_fn: callable, U: LabelledCollection):
|
def true_acc(h: BaseEstimator, acc_fn: callable, U: LabelledCollection):
|
||||||
|
|
Loading…
Reference in New Issue