2023-05-18 22:55:10 +02:00
|
|
|
import numpy as np
|
|
|
|
import scipy.sparse as sp
|
|
|
|
from quapy.data import LabelledCollection
|
|
|
|
from quapy.method.base import BaseQuantifier
|
|
|
|
from sklearn.base import BaseEstimator
|
|
|
|
from sklearn.model_selection import cross_val_predict
|
|
|
|
|
|
|
|
from .data import ExtendedCollection
|
|
|
|
|
|
|
|
|
|
|
|
def _check_prevalence_classes(true_classes, estim_classes, estim_prev):
|
|
|
|
for _cls in true_classes:
|
|
|
|
if _cls not in estim_classes:
|
|
|
|
estim_prev = np.insert(estim_prev, _cls, [0.0], axis=0)
|
|
|
|
return estim_prev
|
|
|
|
|
2023-05-18 22:56:57 +02:00
|
|
|
|
2023-05-18 22:55:10 +02:00
|
|
|
def _get_ex_class(classes, true_class, pred_class):
|
|
|
|
return true_class * classes + pred_class
|
|
|
|
|
2023-05-18 22:56:57 +02:00
|
|
|
|
2023-05-18 22:55:10 +02:00
|
|
|
def _extend_instances(instances, pred_proba):
|
|
|
|
if isinstance(instances, sp.csr_matrix):
|
|
|
|
_pred_proba = sp.csr_matrix(pred_proba)
|
|
|
|
n_x = sp.hstack([instances, _pred_proba])
|
|
|
|
elif isinstance(instances, np.ndarray):
|
|
|
|
n_x = np.concatenate((instances, pred_proba), axis=1)
|
|
|
|
else:
|
|
|
|
raise ValueError("Unsupported matrix format")
|
|
|
|
|
|
|
|
return n_x
|
|
|
|
|
2023-05-18 22:56:57 +02:00
|
|
|
|
2023-05-18 22:55:10 +02:00
|
|
|
def _extend_collection(base: LabelledCollection, pred_proba) -> ExtendedCollection:
|
|
|
|
n_classes = base.n_classes
|
|
|
|
|
|
|
|
# n_X = [ X | predicted probs. ]
|
|
|
|
n_x = _extend_instances(base.X, pred_proba)
|
|
|
|
|
|
|
|
# n_y = (exptected y, predicted y)
|
|
|
|
pred = np.asarray([prob.argmax(axis=0) for prob in pred_proba])
|
|
|
|
n_y = np.asarray(
|
|
|
|
[
|
|
|
|
_get_ex_class(n_classes, true_class, pred_class)
|
|
|
|
for (true_class, pred_class) in zip(base.y, pred)
|
|
|
|
]
|
|
|
|
)
|
|
|
|
|
2023-05-20 20:23:17 +02:00
|
|
|
return ExtendedCollection(n_x, n_y, classes=[*range(0, n_classes * n_classes)])
|
2023-05-18 22:55:10 +02:00
|
|
|
|
|
|
|
|
|
|
|
class AccuracyEstimator:
|
|
|
|
def __init__(self, model: BaseEstimator, q_model: BaseQuantifier):
|
|
|
|
self.model = model
|
|
|
|
self.q_model = q_model
|
|
|
|
self.e_train = None
|
|
|
|
|
|
|
|
def extend(self, base: LabelledCollection, pred_proba=None) -> ExtendedCollection:
|
|
|
|
if not pred_proba:
|
|
|
|
pred_proba = self.model.predict_proba(base.X)
|
|
|
|
return _extend_collection(base, pred_proba)
|
|
|
|
|
|
|
|
def fit(self, train: LabelledCollection | ExtendedCollection):
|
|
|
|
# check if model is fit
|
|
|
|
# self.model.fit(*train.Xy)
|
|
|
|
if isinstance(train, LabelledCollection):
|
|
|
|
pred_prob_train = cross_val_predict(
|
2023-05-20 20:23:17 +02:00
|
|
|
self.model, *train.Xy, method="predict_proba"
|
2023-05-18 22:55:10 +02:00
|
|
|
)
|
|
|
|
|
|
|
|
self.e_train = _extend_collection(train, pred_prob_train)
|
|
|
|
else:
|
|
|
|
self.e_train = train
|
|
|
|
|
|
|
|
self.q_model.fit(self.e_train)
|
|
|
|
|
|
|
|
def estimate(self, instances, ext=False):
|
|
|
|
if not ext:
|
|
|
|
pred_prob = self.model.predict_proba(instances)
|
|
|
|
e_inst = _extend_instances(instances, pred_prob)
|
|
|
|
else:
|
|
|
|
e_inst = instances
|
|
|
|
|
|
|
|
estim_prev = self.q_model.quantify(e_inst)
|
|
|
|
|
|
|
|
return _check_prevalence_classes(
|
2023-05-20 20:23:17 +02:00
|
|
|
self.e_train.classes_, self.q_model.classes_, estim_prev
|
2023-05-18 22:55:10 +02:00
|
|
|
)
|