diff --git a/quacc/method/model_selection.py b/quacc/method/model_selection.py index 2db3f67..ebc2fb8 100644 --- a/quacc/method/model_selection.py +++ b/quacc/method/model_selection.py @@ -2,6 +2,7 @@ import itertools from copy import deepcopy from time import time from typing import Callable, Union +import numpy as np import quapy as qp from quapy.data import LabelledCollection @@ -189,7 +190,7 @@ class GridSearchAE(BaseAccuracyEstimator): by the model selection process. """ - assert hasattr(self, "best_model_"), "quantify called before fit" + assert hasattr(self, "best_model_"), "estimate called before fit" return self.best_model().estimate(instances, ext=ext) def set_params(self, **parameters): @@ -219,6 +220,7 @@ class GridSearchAE(BaseAccuracyEstimator): raise ValueError("best_model called before fit") + class MCAEgsq(MultiClassAccuracyEstimator): def __init__( self, @@ -255,6 +257,11 @@ class MCAEgsq(MultiClassAccuracyEstimator): return self + def estimate(self, instances, ext=False) -> np.ndarray: + e_inst = instances if ext else self._extend_instances(instances) + estim_prev = self.quantifier.quantify(e_inst) + return self._check_prevalence_classes(estim_prev, self.quantifier.best_model().classes_) + class BQAEgsq(BinaryQuantifierAccuracyEstimator): def __init__(