refactored collapse_false

This commit is contained in:
Lorenzo Volpi 2023-11-26 16:32:01 +01:00
parent fdde2cc20f
commit 5e32582782
1 changed files with 6 additions and 10 deletions

View File

@ -17,11 +17,10 @@ class BaseAccuracyEstimator(BaseQuantifier):
self, self,
classifier: BaseEstimator, classifier: BaseEstimator,
quantifier: BaseQuantifier, quantifier: BaseQuantifier,
collapse_false=False,
): ):
self.__check_classifier(classifier) self.__check_classifier(classifier)
self.quantifier = quantifier self.quantifier = quantifier
self.extpol = ExtensionPolicy(collapse_false=collapse_false) self.extpol = ExtensionPolicy()
def __check_classifier(self, classifier): def __check_classifier(self, classifier):
if not hasattr(classifier, "predict_proba"): if not hasattr(classifier, "predict_proba"):
@ -50,23 +49,17 @@ class BaseAccuracyEstimator(BaseQuantifier):
def estimate(self, instances, ext=False) -> np.ndarray: def estimate(self, instances, ext=False) -> np.ndarray:
... ...
@property
def collapse_false(self):
return self.extpol.collapse_false
class ConfidenceBasedAccuracyEstimator(BaseAccuracyEstimator): class ConfidenceBasedAccuracyEstimator(BaseAccuracyEstimator):
def __init__( def __init__(
self, self,
classifier: BaseEstimator, classifier: BaseEstimator,
quantifier: BaseQuantifier, quantifier: BaseQuantifier,
collapse_false=False,
confidence=None, confidence=None,
): ):
super().__init__( super().__init__(
classifier=classifier, classifier=classifier,
quantifier=quantifier, quantifier=quantifier,
collapse_false=collapse_false,
) )
self.__check_confidence(confidence) self.__check_confidence(confidence)
self.calibrator = None self.calibrator = None
@ -137,8 +130,8 @@ class MultiClassAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
classifier=classifier, classifier=classifier,
quantifier=quantifier, quantifier=quantifier,
confidence=confidence, confidence=confidence,
collapse_false=collapse_false,
) )
self.extpol = ExtensionPolicy(collapse_false=collapse_false)
self.e_train = None self.e_train = None
def _get_pred_ext(self, pred_proba: np.ndarray): def _get_pred_ext(self, pred_proba: np.ndarray):
@ -176,6 +169,10 @@ class MultiClassAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
estim_prev = np.insert(estim_prev, _cls, [0.0], axis=0) estim_prev = np.insert(estim_prev, _cls, [0.0], axis=0)
return estim_prev return estim_prev
@property
def collapse_false(self):
return self.extpol.collapse_false
class BinaryQuantifierAccuracyEstimator(ConfidenceBasedAccuracyEstimator): class BinaryQuantifierAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
def __init__( def __init__(
@ -183,7 +180,6 @@ class BinaryQuantifierAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
classifier: BaseEstimator, classifier: BaseEstimator,
quantifier: BaseAccuracyEstimator, quantifier: BaseAccuracyEstimator,
confidence: str = None, confidence: str = None,
collapse_false=False,
): ):
super().__init__( super().__init__(
classifier=classifier, classifier=classifier,