getpath moved and renamed, prevs_from_prot fixed

This commit is contained in:
Lorenzo Volpi 2024-04-05 15:54:19 +02:00
parent dcbbaba361
commit ddce8634ac
1 changed files with 8 additions and 6 deletions

View File

@ -1,15 +1,11 @@
import os
from time import time
import numpy as np
from quapy.data.base import LabelledCollection
from sklearn.base import BaseEstimator
from sklearn.metrics import confusion_matrix
def getpath(basedir, cls_name, acc_name, dataset_name, method_name):
return f"results/{basedir}/{cls_name}/{acc_name}/{dataset_name}/{method_name}.json"
def fit_method(method, V):
tinit = time()
method.fit(V)
@ -45,7 +41,13 @@ def predictionsCAPcont_table(method, test_prot, gen_acc_measure, oracle=False):
def prevs_from_prot(prot):
return [Ui.prevalence() for Ui in prot()]
def _get_plain_prev(prev: np.ndarray):
if prev.shape[0] > 2:
return tuple(prev[1:])
else:
return prev[-1]
return [_get_plain_prev(Ui.prevalence()) for Ui in prot()]
def true_acc(h: BaseEstimator, acc_fn: callable, U: LabelledCollection):