Refactoring completed
This commit is contained in:
parent
6ac18137fa
commit
5959a0d323
|
@ -7,7 +7,6 @@ from typing import List, Optional
|
||||||
class ExtendedCollection(LabelledCollection):
|
class ExtendedCollection(LabelledCollection):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
b_coll: LabelledCollection,
|
|
||||||
instances: np.ndarray | sp.csr_matrix,
|
instances: np.ndarray | sp.csr_matrix,
|
||||||
labels: np.ndarray,
|
labels: np.ndarray,
|
||||||
classes: Optional[List] = None,
|
classes: Optional[List] = None,
|
||||||
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
import quapy as qp
|
||||||
|
|
||||||
|
def from_name(err_name):
|
||||||
|
if err_name == 'f1e':
|
||||||
|
return f1e
|
||||||
|
else:
|
||||||
|
return qp.error.from_name(err_name)
|
||||||
|
|
||||||
|
def f1e(prev):
|
||||||
|
return 1 - f1_score(prev)
|
||||||
|
|
||||||
|
def f1_score(prev):
|
||||||
|
recall = prev[0] / (prev[0] + prev[1])
|
||||||
|
precision = prev[0] / (prev[0] + prev[2])
|
||||||
|
return 2 * (precision * recall) / (precision + recall)
|
|
@ -46,7 +46,7 @@ def _extend_collection(base: LabelledCollection, pred_proba) -> ExtendedCollecti
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
return ExtendedCollection(n_x, n_y, [*range(0, n_classes * n_classes)])
|
return ExtendedCollection(n_x, n_y, classes=[*range(0, n_classes * n_classes)])
|
||||||
|
|
||||||
|
|
||||||
class AccuracyEstimator:
|
class AccuracyEstimator:
|
||||||
|
@ -65,7 +65,7 @@ class AccuracyEstimator:
|
||||||
# self.model.fit(*train.Xy)
|
# self.model.fit(*train.Xy)
|
||||||
if isinstance(train, LabelledCollection):
|
if isinstance(train, LabelledCollection):
|
||||||
pred_prob_train = cross_val_predict(
|
pred_prob_train = cross_val_predict(
|
||||||
self.model, train.Xy, method="predict_proba"
|
self.model, *train.Xy, method="predict_proba"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.e_train = _extend_collection(train, pred_prob_train)
|
self.e_train = _extend_collection(train, pred_prob_train)
|
||||||
|
@ -84,5 +84,5 @@ class AccuracyEstimator:
|
||||||
estim_prev = self.q_model.quantify(e_inst)
|
estim_prev = self.q_model.quantify(e_inst)
|
||||||
|
|
||||||
return _check_prevalence_classes(
|
return _check_prevalence_classes(
|
||||||
e_inst.classes_, self.q_model.classes_, estim_prev
|
self.e_train.classes_, self.q_model.classes_, estim_prev
|
||||||
)
|
)
|
||||||
|
|
|
@ -2,8 +2,12 @@ from quapy.protocol import (
|
||||||
OnLabelledCollectionProtocol,
|
OnLabelledCollectionProtocol,
|
||||||
AbstractStochasticSeededProtocol,
|
AbstractStochasticSeededProtocol,
|
||||||
)
|
)
|
||||||
|
import quapy as qp
|
||||||
|
from typing import Iterable, Callable, Union
|
||||||
|
|
||||||
from .estimator import AccuracyEstimator
|
from .estimator import AccuracyEstimator
|
||||||
|
import pandas as pd
|
||||||
|
import quacc.error as error
|
||||||
|
|
||||||
|
|
||||||
def estimate(estimator: AccuracyEstimator, protocol: AbstractStochasticSeededProtocol):
|
def estimate(estimator: AccuracyEstimator, protocol: AbstractStochasticSeededProtocol):
|
||||||
|
@ -21,5 +25,43 @@ def estimate(estimator: AccuracyEstimator, protocol: AbstractStochasticSeededPro
|
||||||
return base_prevs, true_prevs, estim_prevs
|
return base_prevs, true_prevs, estim_prevs
|
||||||
|
|
||||||
|
|
||||||
def evaluate():
|
def evaluation_report(
|
||||||
pass
|
estimator: AccuracyEstimator,
|
||||||
|
protocol: AbstractStochasticSeededProtocol,
|
||||||
|
error_metrics: Iterable[Union[str, Callable]] = "all",
|
||||||
|
):
|
||||||
|
base_prevs, true_prevs, estim_prevs = estimate(estimator, protocol)
|
||||||
|
|
||||||
|
if error_metrics == "all":
|
||||||
|
error_metrics = ["mae", "rae", "mrae", "kld", "nkld", "f1e"]
|
||||||
|
|
||||||
|
error_funcs = [
|
||||||
|
error.from_name(e) if isinstance(e, str) else e for e in error_metrics
|
||||||
|
]
|
||||||
|
assert all(hasattr(e, "__call__") for e in error_funcs), "invalid error function"
|
||||||
|
error_names = [e.__name__ for e in error_funcs]
|
||||||
|
|
||||||
|
df_cols = ["base_prev", "true_prev", "estim_prev"] + error_names
|
||||||
|
if "f1e" in df_cols:
|
||||||
|
df_cols.remove("f1e")
|
||||||
|
df_cols.extend(["f1e_true", "f1e_estim"])
|
||||||
|
lst = []
|
||||||
|
for base_prev, true_prev, estim_prev in zip(base_prevs, true_prevs, estim_prevs):
|
||||||
|
series = {
|
||||||
|
"base_prev": base_prev,
|
||||||
|
"true_prev": true_prev,
|
||||||
|
"estim_prev": estim_prev,
|
||||||
|
}
|
||||||
|
for error_name, error_metric in zip(error_names, error_funcs):
|
||||||
|
if error_name == "f1e":
|
||||||
|
series["f1e_true"] = error_metric(true_prev)
|
||||||
|
series["f1e_estim"] = error_metric(estim_prev)
|
||||||
|
continue
|
||||||
|
|
||||||
|
score = error_metric(true_prev, estim_prev)
|
||||||
|
series[error_name] = score
|
||||||
|
|
||||||
|
lst.append(series)
|
||||||
|
|
||||||
|
df = pd.DataFrame(lst, columns=df_cols)
|
||||||
|
return df
|
||||||
|
|
|
@ -2,10 +2,16 @@ import numpy as np
|
||||||
import quapy as qp
|
import quapy as qp
|
||||||
import scipy.sparse as sp
|
import scipy.sparse as sp
|
||||||
from quapy.data import LabelledCollection
|
from quapy.data import LabelledCollection
|
||||||
|
from quapy.method.aggregative import SLD
|
||||||
from quapy.protocol import APP, AbstractStochasticSeededProtocol
|
from quapy.protocol import APP, AbstractStochasticSeededProtocol
|
||||||
from sklearn.linear_model import LogisticRegression
|
from sklearn.linear_model import LogisticRegression
|
||||||
from sklearn.model_selection import cross_val_predict
|
from sklearn.model_selection import cross_val_predict
|
||||||
|
|
||||||
|
import quacc.evaluation as eval
|
||||||
|
from quacc.estimator import AccuracyEstimator
|
||||||
|
|
||||||
|
qp.environ['SAMPLE_SIZE'] = 100
|
||||||
|
|
||||||
|
|
||||||
# Extended classes
|
# Extended classes
|
||||||
#
|
#
|
||||||
|
@ -133,9 +139,9 @@ def test_1(dataset_name):
|
||||||
|
|
||||||
orig_prevs, true_prevs, estim_prevs, errors = extend_and_quantify(
|
orig_prevs, true_prevs, estim_prevs, errors = extend_and_quantify(
|
||||||
LogisticRegression(),
|
LogisticRegression(),
|
||||||
qp.method.aggregative.SLD(LogisticRegression()),
|
SLD(LogisticRegression()),
|
||||||
train,
|
train,
|
||||||
APP(test, sample_size=100, n_prevalences=11, repeats=1),
|
APP(test, n_prevalences=11, repeats=1),
|
||||||
)
|
)
|
||||||
|
|
||||||
for orig_prev, true_prev, estim_prev, _errors in zip(
|
for orig_prev, true_prev, estim_prev, _errors in zip(
|
||||||
|
@ -149,6 +155,18 @@ def test_1(dataset_name):
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def test_2(dataset_name):
|
||||||
|
train, test = get_dataset(dataset_name)
|
||||||
|
model = LogisticRegression()
|
||||||
|
model.fit(*train.Xy)
|
||||||
|
estimator = AccuracyEstimator(model, SLD(LogisticRegression()))
|
||||||
|
estimator.fit(train)
|
||||||
|
df = eval.evaluation_report(
|
||||||
|
estimator, APP(test, n_prevalences=11, repeats=1)
|
||||||
|
)
|
||||||
|
print(df.to_string())
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
for dataset_name in [
|
for dataset_name in [
|
||||||
# "hp",
|
# "hp",
|
||||||
|
@ -156,7 +174,7 @@ def main():
|
||||||
"spambase",
|
"spambase",
|
||||||
]:
|
]:
|
||||||
print(dataset_name)
|
print(dataset_name)
|
||||||
test_1(dataset_name)
|
test_2(dataset_name)
|
||||||
print("*" * 50)
|
print("*" * 50)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue