Refactoring completed

Lorenzo Volpi 2023-05-20 20:23:17 +02:00
parent 6ac18137fa
commit 5959a0d323
5 changed files with 84 additions and 10 deletions


@@ -7,7 +7,6 @@ from typing import List, Optional
 class ExtendedCollection(LabelledCollection):
     def __init__(
         self,
-        b_coll: LabelledCollection,
         instances: np.ndarray | sp.csr_matrix,
         labels: np.ndarray,
         classes: Optional[List] = None,

quacc/error.py Normal file

@@ -0,0 +1,15 @@
+import quapy as qp
+
+def from_name(err_name):
+    if err_name == 'f1e':
+        return f1e
+    else:
+        return qp.error.from_name(err_name)
+
+def f1e(prev):
+    return 1 - f1_score(prev)
+
+def f1_score(prev):
+    recall = prev[0] / (prev[0] + prev[1])
+    precision = prev[0] / (prev[0] + prev[2])
+    return 2 * (precision * recall) / (precision + recall)
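A minimal sketch of how the new error module could be called. The layout of the prevalence vector is inferred from the formulas above (prev[0] read as TP, prev[1] as FN, prev[2] as FP); the example values are illustrative and not part of the commit.

import numpy as np

import quacc.error as error

f1e = error.from_name("f1e")   # resolves to the local f1e defined above
mae = error.from_name("mae")   # any other name is delegated to qp.error.from_name

# Illustrative prevalence of the extended classes, assumed ordered [TP, FN, FP, TN]
prev = np.array([0.4, 0.1, 0.2, 0.3])
print(f1e(prev))  # 1 - F1: recall 0.8, precision 0.667, so roughly 0.27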


@@ -46,7 +46,7 @@ def _extend_collection(base: LabelledCollection, pred_proba) -> ExtendedCollection:
         ]
     )
-    return ExtendedCollection(n_x, n_y, [*range(0, n_classes * n_classes)])
+    return ExtendedCollection(n_x, n_y, classes=[*range(0, n_classes * n_classes)])


 class AccuracyEstimator:
@@ -65,7 +65,7 @@ class AccuracyEstimator:
         # self.model.fit(*train.Xy)
         if isinstance(train, LabelledCollection):
             pred_prob_train = cross_val_predict(
-                self.model, train.Xy, method="predict_proba"
+                self.model, *train.Xy, method="predict_proba"
             )
             self.e_train = _extend_collection(train, pred_prob_train)
@@ -84,5 +84,5 @@ class AccuracyEstimator:
         estim_prev = self.q_model.quantify(e_inst)

         return _check_prevalence_classes(
-            e_inst.classes_, self.q_model.classes_, estim_prev
+            self.e_train.classes_, self.q_model.classes_, estim_prev
         )
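Two small fixes in this file: cross_val_predict receives X and y as separate positional arguments (hence *train.Xy), and the prevalence reordering now uses the classes of the extended training set rather than those of the test instances. A minimal sketch of the corrected cross-validation call on toy data (all names and data are illustrative):

import numpy as np
from quapy.data import LabelledCollection
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict

X = np.random.rand(200, 4)
y = np.random.randint(0, 2, 200)
train = LabelledCollection(X, y)

# train.Xy is the (instances, labels) pair; unpacking it gives cross_val_predict
# the separate X and y arguments it expects.
pred_prob_train = cross_val_predict(
    LogisticRegression(), *train.Xy, method="predict_proba"
)
print(pred_prob_train.shape)  # (200, 2): one probability row per training instance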


@@ -2,8 +2,12 @@ from quapy.protocol import (
     OnLabelledCollectionProtocol,
     AbstractStochasticSeededProtocol,
 )
+import quapy as qp
+from typing import Iterable, Callable, Union
 from .estimator import AccuracyEstimator
+import pandas as pd
+import quacc.error as error


 def estimate(estimator: AccuracyEstimator, protocol: AbstractStochasticSeededProtocol):
@@ -21,5 +25,43 @@ def estimate(estimator: AccuracyEstimator, protocol: AbstractStochasticSeededProtocol):
     return base_prevs, true_prevs, estim_prevs


-def evaluate():
-    pass
+def evaluation_report(
+    estimator: AccuracyEstimator,
+    protocol: AbstractStochasticSeededProtocol,
+    error_metrics: Iterable[Union[str, Callable]] = "all",
+):
+    base_prevs, true_prevs, estim_prevs = estimate(estimator, protocol)
+
+    if error_metrics == "all":
+        error_metrics = ["mae", "rae", "mrae", "kld", "nkld", "f1e"]
+
+    error_funcs = [
+        error.from_name(e) if isinstance(e, str) else e for e in error_metrics
+    ]
+    assert all(hasattr(e, "__call__") for e in error_funcs), "invalid error function"
+    error_names = [e.__name__ for e in error_funcs]
+
+    df_cols = ["base_prev", "true_prev", "estim_prev"] + error_names
+    if "f1e" in df_cols:
+        df_cols.remove("f1e")
+        df_cols.extend(["f1e_true", "f1e_estim"])
+
+    lst = []
+    for base_prev, true_prev, estim_prev in zip(base_prevs, true_prevs, estim_prevs):
+        series = {
+            "base_prev": base_prev,
+            "true_prev": true_prev,
+            "estim_prev": estim_prev,
+        }
+        for error_name, error_metric in zip(error_names, error_funcs):
+            if error_name == "f1e":
+                series["f1e_true"] = error_metric(true_prev)
+                series["f1e_estim"] = error_metric(estim_prev)
+                continue
+
+            score = error_metric(true_prev, estim_prev)
+            series[error_name] = score
+
+        lst.append(series)
+
+    df = pd.DataFrame(lst, columns=df_cols)
+    return df
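A self-contained sketch of how the new evaluation_report could be driven end to end. The synthetic data and variable names are illustrative; the AccuracyEstimator and APP usage mirrors test_2 in the test file below.

import numpy as np
import quapy as qp
from quapy.data import LabelledCollection
from quapy.method.aggregative import SLD
from quapy.protocol import APP
from sklearn.linear_model import LogisticRegression

import quacc.evaluation as evaluation
from quacc.estimator import AccuracyEstimator

qp.environ["SAMPLE_SIZE"] = 100

# Toy binary dataset standing in for a real get_dataset(...) call
X = np.random.rand(1000, 4)
y = (X[:, 0] + 0.2 * np.random.rand(1000) > 0.6).astype(int)
train, test = LabelledCollection(X, y).split_stratified(train_prop=0.6)

model = LogisticRegression().fit(*train.Xy)
estimator = AccuracyEstimator(model, SLD(LogisticRegression()))
estimator.fit(train)

# One row per sample generated by the protocol; "f1e" appears as the two columns
# f1e_true / f1e_estim, the other metrics compare true vs. estimated prevalence.
df = evaluation.evaluation_report(
    estimator, APP(test, n_prevalences=11, repeats=1), error_metrics=["mae", "f1e"]
)
print(df.to_string())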


@@ -2,10 +2,16 @@ import numpy as np
 import quapy as qp
 import scipy.sparse as sp
 from quapy.data import LabelledCollection
+from quapy.method.aggregative import SLD
 from quapy.protocol import APP, AbstractStochasticSeededProtocol
 from sklearn.linear_model import LogisticRegression
 from sklearn.model_selection import cross_val_predict

+import quacc.evaluation as eval
+from quacc.estimator import AccuracyEstimator
+
+qp.environ['SAMPLE_SIZE'] = 100
+
 # Extended classes
 #
@@ -133,9 +139,9 @@ def test_1(dataset_name):
     orig_prevs, true_prevs, estim_prevs, errors = extend_and_quantify(
         LogisticRegression(),
-        qp.method.aggregative.SLD(LogisticRegression()),
+        SLD(LogisticRegression()),
         train,
-        APP(test, sample_size=100, n_prevalences=11, repeats=1),
+        APP(test, n_prevalences=11, repeats=1),
     )

     for orig_prev, true_prev, estim_prev, _errors in zip(
@@ -149,6 +155,18 @@ def test_1(dataset_name):
         print()

+
+def test_2(dataset_name):
+    train, test = get_dataset(dataset_name)
+    model = LogisticRegression()
+    model.fit(*train.Xy)
+    estimator = AccuracyEstimator(model, SLD(LogisticRegression()))
+    estimator.fit(train)
+    df = eval.evaluation_report(
+        estimator, APP(test, n_prevalences=11, repeats=1)
+    )
+    print(df.to_string())
+
 def main():
     for dataset_name in [
         # "hp",
@@ -156,7 +174,7 @@ def main():
         "spambase",
     ]:
         print(dataset_name)
-        test_1(dataset_name)
+        test_2(dataset_name)
         print("*" * 50)