refactored collapse_false
This commit is contained in:
parent
fdde2cc20f
commit
5e32582782
|
@ -17,11 +17,10 @@ class BaseAccuracyEstimator(BaseQuantifier):
|
|||
self,
|
||||
classifier: BaseEstimator,
|
||||
quantifier: BaseQuantifier,
|
||||
collapse_false=False,
|
||||
):
|
||||
self.__check_classifier(classifier)
|
||||
self.quantifier = quantifier
|
||||
self.extpol = ExtensionPolicy(collapse_false=collapse_false)
|
||||
self.extpol = ExtensionPolicy()
|
||||
|
||||
def __check_classifier(self, classifier):
|
||||
if not hasattr(classifier, "predict_proba"):
|
||||
|
@ -50,23 +49,17 @@ class BaseAccuracyEstimator(BaseQuantifier):
|
|||
def estimate(self, instances, ext=False) -> np.ndarray:
|
||||
...
|
||||
|
||||
@property
|
||||
def collapse_false(self):
|
||||
return self.extpol.collapse_false
|
||||
|
||||
|
||||
class ConfidenceBasedAccuracyEstimator(BaseAccuracyEstimator):
|
||||
def __init__(
|
||||
self,
|
||||
classifier: BaseEstimator,
|
||||
quantifier: BaseQuantifier,
|
||||
collapse_false=False,
|
||||
confidence=None,
|
||||
):
|
||||
super().__init__(
|
||||
classifier=classifier,
|
||||
quantifier=quantifier,
|
||||
collapse_false=collapse_false,
|
||||
)
|
||||
self.__check_confidence(confidence)
|
||||
self.calibrator = None
|
||||
|
@ -137,8 +130,8 @@ class MultiClassAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
|
|||
classifier=classifier,
|
||||
quantifier=quantifier,
|
||||
confidence=confidence,
|
||||
collapse_false=collapse_false,
|
||||
)
|
||||
self.extpol = ExtensionPolicy(collapse_false=collapse_false)
|
||||
self.e_train = None
|
||||
|
||||
def _get_pred_ext(self, pred_proba: np.ndarray):
|
||||
|
@ -176,6 +169,10 @@ class MultiClassAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
|
|||
estim_prev = np.insert(estim_prev, _cls, [0.0], axis=0)
|
||||
return estim_prev
|
||||
|
||||
@property
|
||||
def collapse_false(self):
|
||||
return self.extpol.collapse_false
|
||||
|
||||
|
||||
class BinaryQuantifierAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
|
||||
def __init__(
|
||||
|
@ -183,7 +180,6 @@ class BinaryQuantifierAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
|
|||
classifier: BaseEstimator,
|
||||
quantifier: BaseAccuracyEstimator,
|
||||
confidence: str = None,
|
||||
collapse_false=False,
|
||||
):
|
||||
super().__init__(
|
||||
classifier=classifier,
|
||||
|
|
Loading…
Reference in New Issue