forked from moreo/QuaPy
Bugfix: some methods were modifying the classifier h
This commit is contained in:
parent
07a29d4b60
commit
8b9b8957f5
|
@ -1,8 +1,8 @@
|
||||||
from ClassifierAccuracy.util.commons import *
|
from ClassifierAccuracy.util.commons import *
|
||||||
from ClassifierAccuracy.util.plotting import plot_diagonal
|
from ClassifierAccuracy.util.plotting import plot_diagonal
|
||||||
|
|
||||||
PROBLEM = 'multiclass'
|
PROBLEM = 'binary'
|
||||||
ORACLE = True
|
ORACLE = False
|
||||||
basedir = PROBLEM+('-oracle' if ORACLE else '')
|
basedir = PROBLEM+('-oracle' if ORACLE else '')
|
||||||
|
|
||||||
|
|
||||||
|
@ -14,6 +14,10 @@ elif PROBLEM == 'multiclass':
|
||||||
qp.environ['SAMPLE_SIZE'] = 250
|
qp.environ['SAMPLE_SIZE'] = 250
|
||||||
NUM_TEST = 1000
|
NUM_TEST = 1000
|
||||||
gen_datasets = gen_multi_datasets
|
gen_datasets = gen_multi_datasets
|
||||||
|
elif PROBLEM == 'tweet':
|
||||||
|
qp.environ['SAMPLE_SIZE'] = 100
|
||||||
|
NUM_TEST = 1000
|
||||||
|
gen_datasets = gen_tweet_datasets
|
||||||
|
|
||||||
|
|
||||||
for (cls_name, h), (dataset_name, (L, V, U)) in itertools.product(gen_classifiers(), gen_datasets()):
|
for (cls_name, h), (dataset_name, (L, V, U)) in itertools.product(gen_classifiers(), gen_datasets()):
|
||||||
|
|
|
@ -65,7 +65,7 @@ class CAPContingencyTable(ClassifierAccuracyPrediction):
|
||||||
the errors in quantification performance
|
the errors in quantification performance
|
||||||
:return: float
|
:return: float
|
||||||
"""
|
"""
|
||||||
cont_table = self.predict_ct(X, oracle)
|
cont_table = self.predict_ct(X, oracle_prev)
|
||||||
raw_acc = self.acc(cont_table)
|
raw_acc = self.acc(cont_table)
|
||||||
norm_acc = np.clip(raw_acc, 0, 1)
|
norm_acc = np.clip(raw_acc, 0, 1)
|
||||||
return norm_acc
|
return norm_acc
|
||||||
|
@ -140,7 +140,7 @@ class ContTableTransferCAP(CAPContingencyTableQ):
|
||||||
def fit(self, val: LabelledCollection):
|
def fit(self, val: LabelledCollection):
|
||||||
y_hat = self.h.predict(val.X)
|
y_hat = self.h.predict(val.X)
|
||||||
y_true = val.y
|
y_true = val.y
|
||||||
self.cont_table = confusion_matrix(y_true, y_pred=y_hat, labels=val.classes_, normalize='all')
|
self.cont_table = confusion_matrix(y_true=y_true, y_pred=y_hat, labels=val.classes_, normalize='all')
|
||||||
self.train_prev = val.prevalence()
|
self.train_prev = val.prevalence()
|
||||||
self.quantifier_fit(val)
|
self.quantifier_fit(val)
|
||||||
return self
|
return self
|
||||||
|
@ -332,7 +332,7 @@ class PabloCAP(ClassifierAccuracyPrediction):
|
||||||
def __init__(self, h, acc_fn, q_class, n_val_samples=100, aggr='mean'):
|
def __init__(self, h, acc_fn, q_class, n_val_samples=100, aggr='mean'):
|
||||||
self.h = h
|
self.h = h
|
||||||
self.acc = acc_fn
|
self.acc = acc_fn
|
||||||
self.q = q_class(h)
|
self.q = q_class(deepcopy(h))
|
||||||
self.n_val_samples = n_val_samples
|
self.n_val_samples = n_val_samples
|
||||||
self.aggr = aggr
|
self.aggr = aggr
|
||||||
assert aggr in ['mean', 'median'], 'unknown aggregation function, use mean or median'
|
assert aggr in ['mean', 'median'], 'unknown aggregation function, use mean or median'
|
||||||
|
|
|
@ -17,7 +17,7 @@ from ClassifierAccuracy.util.tabular import Table
|
||||||
from quapy.method.aggregative import EMQ, ACC, KDEyML
|
from quapy.method.aggregative import EMQ, ACC, KDEyML
|
||||||
|
|
||||||
from quapy.data import LabelledCollection
|
from quapy.data import LabelledCollection
|
||||||
from quapy.data.datasets import fetch_UCIMulticlassLabelledCollection, UCI_MULTICLASS_DATASETS, fetch_lequa2022
|
from quapy.data.datasets import fetch_UCIMulticlassLabelledCollection, UCI_MULTICLASS_DATASETS, fetch_lequa2022, TWITTER_SENTIMENT_DATASETS_TEST
|
||||||
from quapy.data.datasets import fetch_reviews
|
from quapy.data.datasets import fetch_reviews
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,6 +45,9 @@ def gen_multi_datasets(only_names=False)-> [str,[LabelledCollection,LabelledColl
|
||||||
yield dataset_name, split(dataset)
|
yield dataset_name, split(dataset)
|
||||||
|
|
||||||
# yields the 20 newsgroups dataset
|
# yields the 20 newsgroups dataset
|
||||||
|
if only_names:
|
||||||
|
yield "20news", None
|
||||||
|
else:
|
||||||
train = fetch_20newsgroups(subset='train', remove=('headers', 'footers', 'quotes'))
|
train = fetch_20newsgroups(subset='train', remove=('headers', 'footers', 'quotes'))
|
||||||
test = fetch_20newsgroups(subset='test', remove=('headers', 'footers', 'quotes'))
|
test = fetch_20newsgroups(subset='test', remove=('headers', 'footers', 'quotes'))
|
||||||
tfidf = TfidfVectorizer(min_df=5, sublinear_tf=True)
|
tfidf = TfidfVectorizer(min_df=5, sublinear_tf=True)
|
||||||
|
@ -56,10 +59,23 @@ def gen_multi_datasets(only_names=False)-> [str,[LabelledCollection,LabelledColl
|
||||||
yield "20news", (T, V, U)
|
yield "20news", (T, V, U)
|
||||||
|
|
||||||
# yields the T1B@LeQua2022 (training) dataset
|
# yields the T1B@LeQua2022 (training) dataset
|
||||||
|
if only_names:
|
||||||
|
yield "T1B-LeQua2022", None
|
||||||
|
else:
|
||||||
train, _, _ = fetch_lequa2022(task='T1B')
|
train, _, _ = fetch_lequa2022(task='T1B')
|
||||||
yield "T1B-LeQua2022", split(train)
|
yield "T1B-LeQua2022", split(train)
|
||||||
|
|
||||||
|
|
||||||
|
def gen_tweet_datasets(only_names=False)-> [str,[LabelledCollection,LabelledCollection,LabelledCollection]]:
|
||||||
|
for dataset_name in TWITTER_SENTIMENT_DATASETS_TEST:
|
||||||
|
if only_names:
|
||||||
|
yield dataset_name, None
|
||||||
|
else:
|
||||||
|
data = qp.datasets.fetch_twitter(dataset_name, min_df=3, pickle=True)
|
||||||
|
T, V = data.training.split_stratified(0.5, random_state=0)
|
||||||
|
U = data.test
|
||||||
|
yield dataset_name, (T, V, U)
|
||||||
|
|
||||||
|
|
||||||
def gen_bin_datasets(only_names=False) -> [str,[LabelledCollection,LabelledCollection,LabelledCollection]]:
|
def gen_bin_datasets(only_names=False) -> [str,[LabelledCollection,LabelledCollection,LabelledCollection]]:
|
||||||
if only_names:
|
if only_names:
|
||||||
|
@ -104,7 +120,7 @@ def gen_CAP_cont_table(h)->[str,CAPContingencyTable]:
|
||||||
#yield 'QuAcc(EMQ)nxn-noX', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_posteriors=True, add_X=False)
|
#yield 'QuAcc(EMQ)nxn-noX', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_posteriors=True, add_X=False)
|
||||||
#yield 'QuAcc(EMQ)nxn', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()))
|
#yield 'QuAcc(EMQ)nxn', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()))
|
||||||
#yield 'QuAcc(EMQ)nxn-MC', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_maxconf=True)
|
#yield 'QuAcc(EMQ)nxn-MC', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_maxconf=True)
|
||||||
yield 'QuAcc(EMQ)nxn-NE', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_negentropy=True)
|
# yield 'QuAcc(EMQ)nxn-NE', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_negentropy=True)
|
||||||
#yield 'QuAcc(EMQ)nxn-MIS', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_maxinfsoft=True)
|
#yield 'QuAcc(EMQ)nxn-MIS', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_maxinfsoft=True)
|
||||||
#yield 'QuAcc(EMQ)1xn2', QuAcc1xN2(h, acc_fn, EMQ(LogisticRegression()))
|
#yield 'QuAcc(EMQ)1xn2', QuAcc1xN2(h, acc_fn, EMQ(LogisticRegression()))
|
||||||
#yield 'QuAcc(EMQ)1xn2', QuAcc1xN2(h, acc_fn, EMQ(LogisticRegression()))
|
#yield 'QuAcc(EMQ)1xn2', QuAcc1xN2(h, acc_fn, EMQ(LogisticRegression()))
|
||||||
|
|
|
@ -123,7 +123,7 @@ class LabelledCollection:
|
||||||
if len(prevs) == self.n_classes - 1:
|
if len(prevs) == self.n_classes - 1:
|
||||||
prevs = prevs + (1 - sum(prevs),)
|
prevs = prevs + (1 - sum(prevs),)
|
||||||
assert len(prevs) == self.n_classes, 'unexpected number of prevalences'
|
assert len(prevs) == self.n_classes, 'unexpected number of prevalences'
|
||||||
assert sum(prevs) == 1, f'prevalences ({prevs}) wrong range (sum={sum(prevs)})'
|
assert np.isclose(sum(prevs), 1), f'prevalences ({prevs}) wrong range (sum={sum(prevs)})'
|
||||||
|
|
||||||
# Decide how many instances should be taken for each class in order to satisfy the requested prevalence
|
# Decide how many instances should be taken for each class in order to satisfy the requested prevalence
|
||||||
# accurately, and the number of instances in the sample (exactly). If int(size * prevs[i]) (which is
|
# accurately, and the number of instances in the sample (exactly). If int(size * prevs[i]) (which is
|
||||||
|
|
Loading…
Reference in New Issue