adding scmq
This commit is contained in:
parent
88541976e9
commit
24c28edfd9
|
|
@ -0,0 +1,36 @@
|
|||
from sklearn.linear_model import LogisticRegression
|
||||
from statsmodels.sandbox.distributions.genpareto import quant
|
||||
|
||||
import quapy as qp
|
||||
from quapy.protocol import UPP
|
||||
from quapy.method.aggregative import PACC, DMy, EMQ, KDEyML
|
||||
from quapy.method.meta import SCMQ
|
||||
|
||||
qp.environ["SAMPLE_SIZE"]=100
|
||||
|
||||
def train_and_test_model(quantifier, train, test):
|
||||
quantifier.fit(train)
|
||||
report = qp.evaluation.evaluation_report(quantifier, UPP(test), error_metrics=['mae', 'mrae'])
|
||||
print(quantifier.__class__.__name__)
|
||||
print(report.mean(numeric_only=True))
|
||||
|
||||
|
||||
quantifiers = [
|
||||
PACC(),
|
||||
DMy(),
|
||||
EMQ(),
|
||||
KDEyML()
|
||||
]
|
||||
|
||||
classifier = LogisticRegression()
|
||||
|
||||
dataset_name = qp.datasets.UCI_MULTICLASS_DATASETS[0]
|
||||
data = qp.datasets.fetch_UCIMulticlassDataset(dataset_name)
|
||||
train, test = data.train_test
|
||||
|
||||
scmq = SCMQ(classifier, quantifiers)
|
||||
|
||||
train_and_test_model(scmq, train, test)
|
||||
|
||||
for quantifier in quantifiers:
|
||||
train_and_test_model(quantifier, train, test)
|
||||
|
|
@ -591,7 +591,6 @@ class PACC(AggregativeSoftQuantifier):
|
|||
if self.norm not in ACC.NORMALIZATIONS:
|
||||
raise ValueError(f"unknown normalization; valid ones are {ACC.NORMALIZATIONS}")
|
||||
|
||||
|
||||
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||
"""
|
||||
Estimates the misclassification rates
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import itertools
|
||||
from copy import deepcopy
|
||||
from typing import Union
|
||||
from typing import Union, List
|
||||
import numpy as np
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.metrics import f1_score, make_scorer, accuracy_score
|
||||
|
|
@ -12,7 +12,7 @@ from quapy import functional as F
|
|||
from quapy.data import LabelledCollection
|
||||
from quapy.model_selection import GridSearchQ
|
||||
from quapy.method.base import BaseQuantifier, BinaryQuantifier
|
||||
from quapy.method.aggregative import CC, ACC, PACC, HDy, EMQ, AggregativeQuantifier
|
||||
from quapy.method.aggregative import CC, ACC, PACC, HDy, EMQ, AggregativeQuantifier, AggregativeSoftQuantifier
|
||||
|
||||
try:
|
||||
from . import _neural
|
||||
|
|
@ -691,3 +691,66 @@ def EEMQ(classifier, param_grid=None, optim=None, param_mod_sel=None, **kwargs):
|
|||
"""
|
||||
|
||||
return ensembleFactory(classifier, EMQ, param_grid, optim, param_mod_sel, **kwargs)
|
||||
|
||||
|
||||
class SCMQ(AggregativeSoftQuantifier):
|
||||
|
||||
MERGE_FUNCTIONS = ['median']
|
||||
|
||||
def __init__(self, classifier, quantifiers: List[AggregativeSoftQuantifier], merge_fun='median', val_split=5):
|
||||
self.classifier = classifier
|
||||
self.quantifiers = quantifiers
|
||||
assert merge_fun in self.MERGE_FUNCTIONS, f'unknwon {merge_fun=}, valid ones are {self.MERGE_FUNCTIONS}'
|
||||
self.merge_fun = merge_fun
|
||||
self.val_split = val_split
|
||||
|
||||
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||
for quantifier in self.quantifiers:
|
||||
quantifier.classifier = self.classifier
|
||||
quantifier.aggregation_fit(classif_predictions, data)
|
||||
return self
|
||||
|
||||
def aggregate(self, classif_predictions: np.ndarray):
|
||||
prev_predictions = []
|
||||
for quantifier_i in self.quantifiers:
|
||||
prevalence_i = quantifier_i.aggregate(classif_predictions)
|
||||
prev_predictions.append(prevalence_i)
|
||||
return self.merge(prev_predictions)
|
||||
|
||||
def merge(self, prev_predictions):
|
||||
prev_predictions = np.asarray(prev_predictions)
|
||||
if self.merge_fun == 'median':
|
||||
prevalences = np.median(prev_predictions, axis=0)
|
||||
prevalences = F.normalize_prevalence(prevalences, method='l1')
|
||||
elif self.merge_fun == 'mean':
|
||||
prevalences = np.mean(prev_predictions, axis=0)
|
||||
else:
|
||||
raise NotImplementedError(f'merge function {self.merge_fun} not implemented!')
|
||||
return prevalences
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue