refactored collapse_false

This commit is contained in:
Lorenzo Volpi 2023-11-26 16:32:01 +01:00
parent fdde2cc20f
commit 5e32582782
1 changed files with 6 additions and 10 deletions

View File

@ -17,11 +17,10 @@ class BaseAccuracyEstimator(BaseQuantifier):
self,
classifier: BaseEstimator,
quantifier: BaseQuantifier,
collapse_false=False,
):
self.__check_classifier(classifier)
self.quantifier = quantifier
self.extpol = ExtensionPolicy(collapse_false=collapse_false)
self.extpol = ExtensionPolicy()
def __check_classifier(self, classifier):
if not hasattr(classifier, "predict_proba"):
@ -50,23 +49,17 @@ class BaseAccuracyEstimator(BaseQuantifier):
def estimate(self, instances, ext=False) -> np.ndarray:
...
@property
def collapse_false(self):
return self.extpol.collapse_false
class ConfidenceBasedAccuracyEstimator(BaseAccuracyEstimator):
def __init__(
self,
classifier: BaseEstimator,
quantifier: BaseQuantifier,
collapse_false=False,
confidence=None,
):
super().__init__(
classifier=classifier,
quantifier=quantifier,
collapse_false=collapse_false,
)
self.__check_confidence(confidence)
self.calibrator = None
@ -137,8 +130,8 @@ class MultiClassAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
classifier=classifier,
quantifier=quantifier,
confidence=confidence,
collapse_false=collapse_false,
)
self.extpol = ExtensionPolicy(collapse_false=collapse_false)
self.e_train = None
def _get_pred_ext(self, pred_proba: np.ndarray):
@ -176,6 +169,10 @@ class MultiClassAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
estim_prev = np.insert(estim_prev, _cls, [0.0], axis=0)
return estim_prev
@property
def collapse_false(self):
return self.extpol.collapse_false
class BinaryQuantifierAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
def __init__(
@ -183,7 +180,6 @@ class BinaryQuantifierAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
classifier: BaseEstimator,
quantifier: BaseAccuracyEstimator,
confidence: str = None,
collapse_false=False,
):
super().__init__(
classifier=classifier,