refactored collapse_false
This commit is contained in:
parent
fdde2cc20f
commit
5e32582782
|
@ -17,11 +17,10 @@ class BaseAccuracyEstimator(BaseQuantifier):
|
||||||
self,
|
self,
|
||||||
classifier: BaseEstimator,
|
classifier: BaseEstimator,
|
||||||
quantifier: BaseQuantifier,
|
quantifier: BaseQuantifier,
|
||||||
collapse_false=False,
|
|
||||||
):
|
):
|
||||||
self.__check_classifier(classifier)
|
self.__check_classifier(classifier)
|
||||||
self.quantifier = quantifier
|
self.quantifier = quantifier
|
||||||
self.extpol = ExtensionPolicy(collapse_false=collapse_false)
|
self.extpol = ExtensionPolicy()
|
||||||
|
|
||||||
def __check_classifier(self, classifier):
|
def __check_classifier(self, classifier):
|
||||||
if not hasattr(classifier, "predict_proba"):
|
if not hasattr(classifier, "predict_proba"):
|
||||||
|
@ -50,23 +49,17 @@ class BaseAccuracyEstimator(BaseQuantifier):
|
||||||
def estimate(self, instances, ext=False) -> np.ndarray:
|
def estimate(self, instances, ext=False) -> np.ndarray:
|
||||||
...
|
...
|
||||||
|
|
||||||
@property
|
|
||||||
def collapse_false(self):
|
|
||||||
return self.extpol.collapse_false
|
|
||||||
|
|
||||||
|
|
||||||
class ConfidenceBasedAccuracyEstimator(BaseAccuracyEstimator):
|
class ConfidenceBasedAccuracyEstimator(BaseAccuracyEstimator):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
classifier: BaseEstimator,
|
classifier: BaseEstimator,
|
||||||
quantifier: BaseQuantifier,
|
quantifier: BaseQuantifier,
|
||||||
collapse_false=False,
|
|
||||||
confidence=None,
|
confidence=None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
classifier=classifier,
|
classifier=classifier,
|
||||||
quantifier=quantifier,
|
quantifier=quantifier,
|
||||||
collapse_false=collapse_false,
|
|
||||||
)
|
)
|
||||||
self.__check_confidence(confidence)
|
self.__check_confidence(confidence)
|
||||||
self.calibrator = None
|
self.calibrator = None
|
||||||
|
@ -137,8 +130,8 @@ class MultiClassAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
|
||||||
classifier=classifier,
|
classifier=classifier,
|
||||||
quantifier=quantifier,
|
quantifier=quantifier,
|
||||||
confidence=confidence,
|
confidence=confidence,
|
||||||
collapse_false=collapse_false,
|
|
||||||
)
|
)
|
||||||
|
self.extpol = ExtensionPolicy(collapse_false=collapse_false)
|
||||||
self.e_train = None
|
self.e_train = None
|
||||||
|
|
||||||
def _get_pred_ext(self, pred_proba: np.ndarray):
|
def _get_pred_ext(self, pred_proba: np.ndarray):
|
||||||
|
@ -176,6 +169,10 @@ class MultiClassAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
|
||||||
estim_prev = np.insert(estim_prev, _cls, [0.0], axis=0)
|
estim_prev = np.insert(estim_prev, _cls, [0.0], axis=0)
|
||||||
return estim_prev
|
return estim_prev
|
||||||
|
|
||||||
|
@property
|
||||||
|
def collapse_false(self):
|
||||||
|
return self.extpol.collapse_false
|
||||||
|
|
||||||
|
|
||||||
class BinaryQuantifierAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
|
class BinaryQuantifierAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -183,7 +180,6 @@ class BinaryQuantifierAccuracyEstimator(ConfidenceBasedAccuracyEstimator):
|
||||||
classifier: BaseEstimator,
|
classifier: BaseEstimator,
|
||||||
quantifier: BaseAccuracyEstimator,
|
quantifier: BaseAccuracyEstimator,
|
||||||
confidence: str = None,
|
confidence: str = None,
|
||||||
collapse_false=False,
|
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
classifier=classifier,
|
classifier=classifier,
|
||||||
|
|
Loading…
Reference in New Issue