diff --git a/quacc/experiments/util.py b/quacc/experiments/util.py index d2af476..6d35572 100644 --- a/quacc/experiments/util.py +++ b/quacc/experiments/util.py @@ -1,15 +1,11 @@ -import os from time import time +import numpy as np from quapy.data.base import LabelledCollection from sklearn.base import BaseEstimator from sklearn.metrics import confusion_matrix -def getpath(basedir, cls_name, acc_name, dataset_name, method_name): - return f"results/{basedir}/{cls_name}/{acc_name}/{dataset_name}/{method_name}.json" - - def fit_method(method, V): tinit = time() method.fit(V) @@ -45,7 +41,13 @@ def predictionsCAPcont_table(method, test_prot, gen_acc_measure, oracle=False): def prevs_from_prot(prot): - return [Ui.prevalence() for Ui in prot()] + def _get_plain_prev(prev: np.ndarray): + if prev.shape[0] > 2: + return tuple(prev[1:]) + else: + return prev[-1] + + return [_get_plain_prev(Ui.prevalence()) for Ui in prot()] def true_acc(h: BaseEstimator, acc_fn: callable, U: LabelledCollection):