QuAcc/quacc/models/base.py

from abc import ABC, abstractmethod
from copy import deepcopy
import numpy as np
import quapy as qp
import quapy.functional as F
from quapy.protocol import UPP
from sklearn.base import BaseEstimator
from sklearn.metrics import confusion_matrix
from quacc.legacy.data import LabelledCollection


class ClassifierAccuracyPrediction(ABC):

    def __init__(self, h: BaseEstimator, acc: callable):
        self.h = h
        self.acc = acc

    @abstractmethod
    def fit(self, val: LabelledCollection): ...

    @abstractmethod
    def predict(self, X, oracle_prev=None):
        """
        Evaluates the accuracy function on the predicted contingency table

        :param X: test data
        :param oracle_prev: np.ndarray with the class prevalence of the test set as estimated by
            an oracle. This is meant to test the effect of the errors in CAP that are explained by
            the errors in quantification performance
        :return: float
        """
        return ...

    def true_acc(self, sample: LabelledCollection):
        y_pred = self.h.predict(sample.X)
        y_true = sample.y
        conf_table = confusion_matrix(y_true, y_pred=y_pred, labels=sample.classes_)
        return self.acc(conf_table)
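
# Note (illustrative, not part of the original file): the `acc` callable is applied to a
# confusion matrix as returned by sklearn's confusion_matrix (rows = true classes,
# columns = predicted classes), so any metric readable off that table can be passed,
# e.g. a vanilla-accuracy sketch such as:
#
#     def vanilla_accuracy(conf_table):
#         return np.trace(conf_table) / conf_table.sum()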


class SebastianiCAP(ClassifierAccuracyPrediction):

    def __init__(
        self, h, acc_fn, q_class, n_val_samples=500, alpha=0.3, predict_train_prev=True
    ):
        self.h = h
        self.acc = acc_fn
        self.q = q_class(h)
        self.n_val_samples = n_val_samples
        self.alpha = alpha
        self.sample_size = qp.environ["SAMPLE_SIZE"]
        self.predict_train_prev = predict_train_prev

    def fit(self, val: LabelledCollection):
        v2, v1 = val.split_stratified(train_prop=0.5)
        self.q.fit(v1, fit_classifier=False, val_split=v1)

        # precompute classifier predictions on samples
        gen_samples = UPP(
            v2,
            repeats=self.n_val_samples,
            sample_size=self.sample_size,
            return_type="labelled_collection",
        )
        self.sigma_acc = [self.true_acc(sigma_i) for sigma_i in gen_samples()]

        # precompute prevalence predictions on samples
        if self.predict_train_prev:
            gen_samples.on_preclassified_instances(self.q.classify(v2.X), in_place=True)
            self.sigma_pred_prevs = [
                self.q.aggregate(sigma_i.X) for sigma_i in gen_samples()
            ]
        else:
            self.sigma_pred_prevs = [sigma_i.prevalence() for sigma_i in gen_samples()]

    def predict(self, X, oracle_prev=None):
        if oracle_prev is None:
            test_pred_prev = self.q.quantify(X)
        else:
            test_pred_prev = oracle_prev

        if self.alpha > 0:
            # select samples from V2 with predicted prevalence close to the predicted prevalence for U
            selected_accuracies = []
            for pred_prev_i, acc_i in zip(self.sigma_pred_prevs, self.sigma_acc):
                max_discrepancy = np.max(np.abs(pred_prev_i - test_pred_prev))
                if max_discrepancy < self.alpha:
                    selected_accuracies.append(acc_i)
            return np.median(selected_accuracies)
        else:
            # weighted average: weight each V2 sample by how close its predicted prevalence is to the one predicted for U
            accum_weight = 0
            moving_mean = 0
            epsilon = 10e-4
            for pred_prev_i, acc_i in zip(self.sigma_pred_prevs, self.sigma_acc):
                max_discrepancy = np.max(np.abs(pred_prev_i - test_pred_prev))
                weight = -np.log(max_discrepancy + epsilon)
                accum_weight += weight
                moving_mean += weight * acc_i
            return moving_mean / accum_weight
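
# Note (illustrative, not part of the original file): with alpha > 0, SebastianiCAP.predict
# returns the median accuracy over those validation samples whose predicted prevalence lies
# within an L-infinity distance of alpha from the prevalence predicted for the test set
# (np.median over an empty selection yields nan if no sample is close enough); with
# alpha <= 0, every sample contributes with weight -log(max_discrepancy + epsilon), so
# closer samples weigh more.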


class PabloCAP(ClassifierAccuracyPrediction):

    def __init__(self, h, acc_fn, q_class, n_val_samples=100, aggr="mean"):
        self.h = h
        self.acc = acc_fn
        self.q = q_class(deepcopy(h))
        self.n_val_samples = n_val_samples
        self.aggr = aggr
        assert aggr in [
            "mean",
            "median",
        ], "unknown aggregation function, use mean or median"

    def fit(self, val: LabelledCollection):
        self.q.fit(val)
        label_predictions = self.h.predict(val.X)
        self.pre_classified = LabelledCollection(
            instances=label_predictions, labels=val.labels
        )
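
    # Note (illustrative, not part of the original file): `pre_classified` pairs the
    # classifier's predicted labels (stored as instances) with the true labels, so that
    # sampling it at the prevalence predicted for the test set reproduces plausible
    # (y_pred, y_true) pairs from which a confusion matrix, and hence the accuracy,
    # can be estimated.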

    def predict(self, X, oracle_prev=None):
        if oracle_prev is None:
            pred_prev = F.smooth(self.q.quantify(X))
        else:
            pred_prev = oracle_prev

        X_size = X.shape[0]

        acc_estim = []
        for _ in range(self.n_val_samples):
            sigma_i = self.pre_classified.sampling(X_size, *pred_prev[:-1])
            y_pred, y_true = sigma_i.Xy
            conf_table = confusion_matrix(
                y_true, y_pred=y_pred, labels=sigma_i.classes_
            )
            acc_i = self.acc(conf_table)
            acc_estim.append(acc_i)

        if self.aggr == "mean":
            return np.mean(acc_estim)
        elif self.aggr == "median":
            return np.median(acc_estim)
        else:
            raise ValueError("unknown aggregation function")
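

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file). It assumes that the
# legacy LabelledCollection mirrors quapy's LabelledCollection API and that PACC
# from quapy is an acceptable q_class; names such as `vanilla_accuracy` are made
# up for this example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from quapy.method.aggregative import PACC
    from sklearn.linear_model import LogisticRegression

    # sample size used by the UPP protocol inside SebastianiCAP.fit
    qp.environ["SAMPLE_SIZE"] = 100

    # toy binary dataset: two Gaussian blobs, one per class
    rng = np.random.default_rng(0)
    X = np.vstack([rng.normal(0.0, 1.0, (500, 2)), rng.normal(2.0, 1.0, (500, 2))])
    y = np.repeat([0, 1], 500)
    data = LabelledCollection(X, y)
    train, rest = data.split_stratified(train_prop=0.4)
    val, test = rest.split_stratified(train_prop=0.5)

    # classifier whose accuracy on the (unlabelled) test set we want to predict
    h = LogisticRegression().fit(*train.Xy)

    def vanilla_accuracy(conf_table):
        # fraction of correctly classified instances, read off the confusion matrix
        return np.trace(conf_table) / conf_table.sum()

    cap = SebastianiCAP(h, vanilla_accuracy, PACC)
    cap.fit(val)
    print("estimated accuracy:", cap.predict(test.X))
    print("true accuracy:     ", cap.true_acc(test))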