"""Accuracy estimators built on top of quantifiers.

Each estimator wraps a probabilistic classifier and a quantifier.  The
classifier's posterior probabilities (optionally preceded by a confidence
score) are attached to the data through ``ExtendedCollection``, and the
quantifier is trained / queried on that extended representation.
"""

import math
from abc import abstractmethod
from copy import deepcopy
from typing import List

import numpy as np
from quapy.data import LabelledCollection
from quapy.method.aggregative import BaseQuantifier
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator
from sklearn.exceptions import NotFittedError

from quacc.data import ExtendedCollection


class BaseAccuracyEstimator(BaseQuantifier):
    """Common machinery shared by the accuracy estimators of this module.

    Stores the classifier / quantifier pair and knows how to build the
    extended representation (posterior probabilities, optionally preceded
    by one confidence column) that the quantifier works on.
    """

    def __init__(
        self,
        classifier: BaseEstimator,
        quantifier: BaseQuantifier,
        confidence=None,
    ):
        """Store the wrapped models.

        :param classifier: estimator exposing ``predict_proba``.
        :param quantifier: quantifier fitted on the extended data.
        :param confidence: ``"max_conf"``, ``"entropy"`` or ``None``; selects
            the confidence column prepended to the posteriors.
        :raises ValueError: if ``classifier`` has no ``predict_proba``.
        """
        self.__check_classifier(classifier)
        self.quantifier = quantifier
        self.confidence = confidence

    def __check_classifier(self, classifier):
        """Validate ``classifier`` and store it as ``self.classifier``.

        :raises ValueError: if ``classifier`` has no ``predict_proba``.
        """
        if not hasattr(classifier, "predict_proba"):
            raise ValueError(
                f"Passed classifier {classifier.__class__.__name__} cannot predict probabilities."
            )
        self.classifier = classifier

    def __get_confidence(self):
        """Return the confidence function named by ``self.confidence``.

        :return: a callable mapping a 2-D array of probabilities to one value
            per row, or ``None`` when no confidence is configured.
        """

        def max_conf(probas):
            # Highest posterior per row, rescaled so that the uniform
            # distribution (1 / n_columns) maps to 0 and certainty maps to 1.
            _mc = np.max(probas, axis=-1)
            _min = 1.0 / probas.shape[1]
            _norm_mc = (_mc - _min) / (1.0 - _min)
            return _norm_mc

        def entropy(probas):
            # sum(p * log(p)) per row; the 1e-20 offset avoids log(0).
            # NOTE(review): there is no leading minus sign, so this is the
            # negative of the Shannon entropy - confirm this is intended.
            _ent = np.sum(np.multiply(probas, np.log(probas + 1e-20)), axis=1)
            return _ent

        if self.confidence is None:
            return None

        conf_functions = {
            "max_conf": max_conf,
            "entropy": entropy,
        }
        # NOTE(review): an unrecognised name yields None, i.e. the confidence
        # column is silently dropped rather than reported as an error.
        return conf_functions.get(self.confidence, None)

    def __get_ext(self, pred_proba):
        """Build the extension matrix for the given posteriors.

        :param pred_proba: 2-D array of posterior probabilities.
        :return: ``pred_proba`` unchanged when no confidence function is
            active, otherwise ``pred_proba`` with the confidence values
            prepended as the first column.
        """
        _ext = pred_proba
        _f_conf = self.__get_confidence()
        if _f_conf is not None:
            # Turn the per-row confidences into a column vector.
            _confs = _f_conf(pred_proba).reshape((len(pred_proba), 1))
            _ext = np.concatenate((_confs, pred_proba), axis=1)

        return _ext

    def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection:
        """Extend a labelled collection with the classifier's posteriors.

        :param coll: collection to extend.
        :param pred_proba: precomputed posteriors for ``coll.X``; computed
            with ``self.classifier.predict_proba`` when ``None``.
        :return: the result of ``ExtendedCollection.extend_collection``.
        """
        if pred_proba is None:
            pred_proba = self.classifier.predict_proba(coll.X)

        _ext = self.__get_ext(pred_proba)
        return ExtendedCollection.extend_collection(coll, pred_proba=_ext)

    def _extend_instances(self, instances: np.ndarray | csr_matrix, pred_proba=None):
        """Extend raw instances with the classifier's posteriors.

        :param instances: instances to extend.
        :param pred_proba: precomputed posteriors for ``instances``; computed
            with ``self.classifier.predict_proba`` when ``None``.
        :return: the result of ``ExtendedCollection.extend_instances``.
        """
        if pred_proba is None:
            pred_proba = self.classifier.predict_proba(instances)

        _ext = self.__get_ext(pred_proba)
        return ExtendedCollection.extend_instances(instances, _ext)

    @abstractmethod
    def fit(self, train: LabelledCollection | ExtendedCollection):
        """Train the estimator on ``train``."""
        ...

    @abstractmethod
    def estimate(self, instances, ext=False) -> np.ndarray:
        """Estimate prevalences for ``instances``.

        :param instances: instances to estimate on.
        :param ext: when true, ``instances`` are taken as already extended.
        """
        ...


class MultiClassAccuracyEstimator(BaseAccuracyEstimator):
    """Estimator that fits a single quantifier on the extended training set."""

    def __init__(
        self,
        classifier: BaseEstimator,
        quantifier: BaseQuantifier,
        confidence: str | None = None,
    ):
        """See :class:`BaseAccuracyEstimator` for the parameters."""
        super().__init__(
            classifier=classifier,
            quantifier=quantifier,
            confidence=confidence,
        )
        # Extended training collection; set by fit().
        self.e_train = None

    def fit(self, train: LabelledCollection):
        """Extend ``train`` and fit the quantifier on it.

        :return: ``self``.
        """
        self.e_train = self.extend(train)

        self.quantifier.fit(self.e_train)

        return self

    def estimate(self, instances, ext=False) -> np.ndarray:
        """Quantify ``instances`` and pad classes the quantifier does not know.

        :param instances: instances to estimate on.
        :param ext: when true, ``instances`` are used as they are; otherwise
            they are first extended with ``_extend_instances``.
        :return: the estimated prevalence vector.
        """
        e_inst = instances if ext else self._extend_instances(instances)

        estim_prev = self.quantifier.quantify(e_inst)
        return self._check_prevalence_classes(estim_prev, self.quantifier.classes_)

    def _check_prevalence_classes(self, estim_prev, estim_classes) -> np.ndarray:
        """Insert a 0.0 prevalence for each training class missing from
        ``estim_classes``.

        The class value itself is used as the insertion index.
        NOTE(review): this presumes classes are integer positions visited in
        ascending order - confirm against ``ExtendedCollection.classes_``.
        """
        true_classes = self.e_train.classes_
        for _cls in true_classes:
            if _cls not in estim_classes:
                estim_prev = np.insert(estim_prev, _cls, [0.0], axis=0)
        return estim_prev


class BinaryQuantifierAccuracyEstimator(BaseAccuracyEstimator):
    """Estimator that fits one copy of the quantifier per training split.

    The splits are those returned by ``ExtendedCollection.split_by_pred``.
    """

    def __init__(
        self,
        classifier: BaseEstimator,
        quantifier: BaseQuantifier,
        confidence: str | None = None,
    ):
        """See :class:`BaseAccuracyEstimator` for the parameters.

        ``quantifier`` serves as a template: ``fit`` deep-copies it once per
        split and never fits the template itself.
        """
        super().__init__(
            classifier=classifier,
            quantifier=quantifier,
            confidence=confidence,
        )
        self.quantifiers = []
        self.e_trains = []
        # Fitted state; both stay None until fit() runs.
        self.e_train = None
        self.n_classes = None

    def fit(self, train: LabelledCollection | ExtendedCollection):
        """Extend ``train``, split it and fit one quantifier copy per split.

        :return: ``self``.
        """
        self.e_train = self.extend(train)

        self.n_classes = self.e_train.n_classes
        self.e_trains = self.e_train.split_by_pred()

        self.quantifiers = []
        for split in self.e_trains:
            quant = deepcopy(self.quantifier)
            quant.fit(split)
            self.quantifiers.append(quant)

        return self

    def estimate(self, instances, ext=False):
        """Quantify each split of ``instances`` and interleave the results.

        :param instances: instances to estimate on.
        :param ext: when true, ``instances`` are used as they are; otherwise
            they are first extended with ``_extend_instances``.
        :return: flat array in which the first entries of every per-split
            estimate come first, then the second entries, and so on.
        :raises NotFittedError: if called before ``fit``.  The exception is
            also an ``AttributeError``, which is what an unfitted instance
            raised previously.
        """
        # TODO: test
        if self.n_classes is None:
            raise NotFittedError(
                f"This {self.__class__.__name__} instance is not fitted yet; "
                "call 'fit' before using 'estimate'."
            )

        e_inst = instances if ext else self._extend_instances(instances)

        # NOTE(review): presumably the extended collection has the square of
        # the original number of classes, hence the square root - confirm.
        _ncl = int(math.sqrt(self.n_classes))
        s_inst, norms = ExtendedCollection.split_inst_by_pred(_ncl, e_inst)
        estim_prevs = self._quantify_helper(s_inst, norms)

        # zip(*...) transposes the per-split estimates before flattening.
        estim_prev = np.array([prev_row for prev_row in zip(*estim_prevs)]).flatten()
        return estim_prev

    def _quantify_helper(
        self,
        s_inst: List[np.ndarray | csr_matrix],
        norms: List[float],
    ):
        """Quantify every split with its own quantifier, scaled by its norm.

        :param s_inst: instance splits, matched positionally with
            ``self.quantifiers``.
        :param norms: one scaling factor per split.
        :return: list with one array per split; an empty split contributes
            ``[0.0, 0.0]`` without calling its quantifier.
        """
        estim_prevs = []
        for quant, inst, norm in zip(self.quantifiers, s_inst, norms):
            if inst.shape[0] > 0:
                estim_prevs.append(quant.quantify(inst) * norm)
            else:
                estim_prevs.append(np.asarray([0.0, 0.0]))

        return estim_prevs


# Short aliases for the estimator classes.
BAE = BaseAccuracyEstimator
MCAE = MultiClassAccuracyEstimator
BQAE = BinaryQuantifierAccuracyEstimator