adding scmq
This commit is contained in:
parent
88541976e9
commit
24c28edfd9
|
|
@ -0,0 +1,36 @@
|
||||||
|
from sklearn.linear_model import LogisticRegression
|
||||||
|
from statsmodels.sandbox.distributions.genpareto import quant
|
||||||
|
|
||||||
|
import quapy as qp
|
||||||
|
from quapy.protocol import UPP
|
||||||
|
from quapy.method.aggregative import PACC, DMy, EMQ, KDEyML
|
||||||
|
from quapy.method.meta import SCMQ
|
||||||
|
|
||||||
|
qp.environ["SAMPLE_SIZE"]=100
|
||||||
|
|
||||||
|
def train_and_test_model(quantifier, train, test):
|
||||||
|
quantifier.fit(train)
|
||||||
|
report = qp.evaluation.evaluation_report(quantifier, UPP(test), error_metrics=['mae', 'mrae'])
|
||||||
|
print(quantifier.__class__.__name__)
|
||||||
|
print(report.mean(numeric_only=True))
|
||||||
|
|
||||||
|
|
||||||
|
quantifiers = [
|
||||||
|
PACC(),
|
||||||
|
DMy(),
|
||||||
|
EMQ(),
|
||||||
|
KDEyML()
|
||||||
|
]
|
||||||
|
|
||||||
|
classifier = LogisticRegression()
|
||||||
|
|
||||||
|
dataset_name = qp.datasets.UCI_MULTICLASS_DATASETS[0]
|
||||||
|
data = qp.datasets.fetch_UCIMulticlassDataset(dataset_name)
|
||||||
|
train, test = data.train_test
|
||||||
|
|
||||||
|
scmq = SCMQ(classifier, quantifiers)
|
||||||
|
|
||||||
|
train_and_test_model(scmq, train, test)
|
||||||
|
|
||||||
|
for quantifier in quantifiers:
|
||||||
|
train_and_test_model(quantifier, train, test)
|
||||||
|
|
@ -591,7 +591,6 @@ class PACC(AggregativeSoftQuantifier):
|
||||||
if self.norm not in ACC.NORMALIZATIONS:
|
if self.norm not in ACC.NORMALIZATIONS:
|
||||||
raise ValueError(f"unknown normalization; valid ones are {ACC.NORMALIZATIONS}")
|
raise ValueError(f"unknown normalization; valid ones are {ACC.NORMALIZATIONS}")
|
||||||
|
|
||||||
|
|
||||||
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||||
"""
|
"""
|
||||||
Estimates the misclassification rates
|
Estimates the misclassification rates
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import itertools
|
import itertools
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Union
|
from typing import Union, List
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.linear_model import LogisticRegression
|
from sklearn.linear_model import LogisticRegression
|
||||||
from sklearn.metrics import f1_score, make_scorer, accuracy_score
|
from sklearn.metrics import f1_score, make_scorer, accuracy_score
|
||||||
|
|
@ -12,7 +12,7 @@ from quapy import functional as F
|
||||||
from quapy.data import LabelledCollection
|
from quapy.data import LabelledCollection
|
||||||
from quapy.model_selection import GridSearchQ
|
from quapy.model_selection import GridSearchQ
|
||||||
from quapy.method.base import BaseQuantifier, BinaryQuantifier
|
from quapy.method.base import BaseQuantifier, BinaryQuantifier
|
||||||
from quapy.method.aggregative import CC, ACC, PACC, HDy, EMQ, AggregativeQuantifier
|
from quapy.method.aggregative import CC, ACC, PACC, HDy, EMQ, AggregativeQuantifier, AggregativeSoftQuantifier
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from . import _neural
|
from . import _neural
|
||||||
|
|
@ -691,3 +691,66 @@ def EEMQ(classifier, param_grid=None, optim=None, param_mod_sel=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return ensembleFactory(classifier, EMQ, param_grid, optim, param_mod_sel, **kwargs)
|
return ensembleFactory(classifier, EMQ, param_grid, optim, param_mod_sel, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class SCMQ(AggregativeSoftQuantifier):
|
||||||
|
|
||||||
|
MERGE_FUNCTIONS = ['median']
|
||||||
|
|
||||||
|
def __init__(self, classifier, quantifiers: List[AggregativeSoftQuantifier], merge_fun='median', val_split=5):
|
||||||
|
self.classifier = classifier
|
||||||
|
self.quantifiers = quantifiers
|
||||||
|
assert merge_fun in self.MERGE_FUNCTIONS, f'unknwon {merge_fun=}, valid ones are {self.MERGE_FUNCTIONS}'
|
||||||
|
self.merge_fun = merge_fun
|
||||||
|
self.val_split = val_split
|
||||||
|
|
||||||
|
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||||
|
for quantifier in self.quantifiers:
|
||||||
|
quantifier.classifier = self.classifier
|
||||||
|
quantifier.aggregation_fit(classif_predictions, data)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def aggregate(self, classif_predictions: np.ndarray):
|
||||||
|
prev_predictions = []
|
||||||
|
for quantifier_i in self.quantifiers:
|
||||||
|
prevalence_i = quantifier_i.aggregate(classif_predictions)
|
||||||
|
prev_predictions.append(prevalence_i)
|
||||||
|
return self.merge(prev_predictions)
|
||||||
|
|
||||||
|
def merge(self, prev_predictions):
|
||||||
|
prev_predictions = np.asarray(prev_predictions)
|
||||||
|
if self.merge_fun == 'median':
|
||||||
|
prevalences = np.median(prev_predictions, axis=0)
|
||||||
|
prevalences = F.normalize_prevalence(prevalences, method='l1')
|
||||||
|
elif self.merge_fun == 'mean':
|
||||||
|
prevalences = np.mean(prev_predictions, axis=0)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'merge function {self.merge_fun} not implemented!')
|
||||||
|
return prevalences
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue