generating tables with captions, added 20 newsgroups and lequa 2022 t1b
This commit is contained in:
parent
d81bf305a3
commit
07a29d4b60
|
@ -2,7 +2,7 @@ from ClassifierAccuracy.util.commons import *
|
|||
from ClassifierAccuracy.util.plotting import plot_diagonal
|
||||
|
||||
PROBLEM = 'multiclass'
|
||||
ORACLE = False
|
||||
ORACLE = True
|
||||
basedir = PROBLEM+('-oracle' if ORACLE else '')
|
||||
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ from ClassifierAccuracy.util.tabular import Table
|
|||
from quapy.method.aggregative import EMQ, ACC, KDEyML
|
||||
|
||||
from quapy.data import LabelledCollection
|
||||
from quapy.data.datasets import fetch_UCIMulticlassLabelledCollection, UCI_MULTICLASS_DATASETS
|
||||
from quapy.data.datasets import fetch_UCIMulticlassLabelledCollection, UCI_MULTICLASS_DATASETS, fetch_lequa2022
|
||||
from quapy.data.datasets import fetch_reviews
|
||||
|
||||
|
||||
|
@ -43,6 +43,8 @@ def gen_multi_datasets(only_names=False)-> [str,[LabelledCollection,LabelledColl
|
|||
else:
|
||||
dataset = fetch_UCIMulticlassLabelledCollection(dataset_name)
|
||||
yield dataset_name, split(dataset)
|
||||
|
||||
# yields the 20 newsgroups dataset
|
||||
train = fetch_20newsgroups(subset='train', remove=('headers', 'footers', 'quotes'))
|
||||
test = fetch_20newsgroups(subset='test', remove=('headers', 'footers', 'quotes'))
|
||||
tfidf = TfidfVectorizer(min_df=5, sublinear_tf=True)
|
||||
|
@ -53,6 +55,10 @@ def gen_multi_datasets(only_names=False)-> [str,[LabelledCollection,LabelledColl
|
|||
T, V = train.split_stratified(train_prop=0.5, random_state=0)
|
||||
yield "20news", (T, V, U)
|
||||
|
||||
# yields the T1B@LeQua2022 (training) dataset
|
||||
train, _, _ = fetch_lequa2022(task='T1B')
|
||||
yield "T1B-LeQua2022", split(train)
|
||||
|
||||
|
||||
|
||||
def gen_bin_datasets(only_names=False) -> [str,[LabelledCollection,LabelledCollection,LabelledCollection]]:
|
||||
|
@ -92,7 +98,7 @@ def gen_CAP(h, acc_fn, with_oracle=False)->[str, ClassifierAccuracyPrediction]:
|
|||
def gen_CAP_cont_table(h)->[str,CAPContingencyTable]:
|
||||
acc_fn = None
|
||||
yield 'Naive', NaiveCAP(h, acc_fn)
|
||||
#yield 'CT-PPS-EMQ', ContTableTransferCAP(h, acc_fn, EMQ(LogisticRegression()))
|
||||
yield 'CT-PPS-EMQ', ContTableTransferCAP(h, acc_fn, EMQ(LogisticRegression()))
|
||||
#yield 'CT-PPS-KDE', ContTableTransferCAP(h, acc_fn, KDEyML(LogisticRegression(class_weight='balanced'), bandwidth=0.01))
|
||||
yield 'CT-PPS-KDE05', ContTableTransferCAP(h, acc_fn, KDEyML(LogisticRegression(class_weight='balanced'), bandwidth=0.05))
|
||||
#yield 'QuAcc(EMQ)nxn-noX', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()), add_posteriors=True, add_X=False)
|
||||
|
@ -308,6 +314,8 @@ def gen_tables(basedir, datasets):
|
|||
|
||||
os.makedirs('./tables', exist_ok=True)
|
||||
|
||||
with_oracle = 'oracle' in basedir
|
||||
|
||||
tex_doc = """
|
||||
\\documentclass[10pt,a4paper]{article}
|
||||
\\usepackage[utf8]{inputenc}
|
||||
|
@ -322,40 +330,45 @@ def gen_tables(basedir, datasets):
|
|||
\\begin{document}
|
||||
"""
|
||||
|
||||
classifier = classifiers[0]
|
||||
for metric in [measure for measure, _ in gen_acc_measure()]:
|
||||
for classifier in classifiers:
|
||||
for metric in [measure for measure, _ in gen_acc_measure()]:
|
||||
|
||||
table = Table(datasets, methods, prec_mean=5, clean_zero=True)
|
||||
for method, dataset in itertools.product(methods, datasets):
|
||||
path = getpath(basedir, classifier, metric, dataset, method)
|
||||
if not os.path.exists(path):
|
||||
print('missing ', path)
|
||||
continue
|
||||
results = json.load(open(path, 'r'))
|
||||
true_acc = results['true_acc']
|
||||
estim_acc = np.asarray(results['estim_acc'])
|
||||
if any(np.isnan(estim_acc)):
|
||||
print(f'nan values found in {method=} {dataset=}')
|
||||
continue
|
||||
if any(estim_acc>1.00001):
|
||||
print(f'values >1 found in {method=} {dataset=} [max={estim_acc.max()}]')
|
||||
continue
|
||||
if any(estim_acc<-0.00001):
|
||||
print(f'values <0 found in {method=} {dataset=} [min={estim_acc.min()}]')
|
||||
continue
|
||||
errors = cap_errors(true_acc, estim_acc)
|
||||
table.add(dataset, method, errors)
|
||||
table = Table(datasets, methods, prec_mean=5, clean_zero=True)
|
||||
for method, dataset in itertools.product(methods, datasets):
|
||||
path = getpath(basedir, classifier, metric, dataset, method)
|
||||
if not os.path.exists(path):
|
||||
print('missing ', path)
|
||||
continue
|
||||
results = json.load(open(path, 'r'))
|
||||
true_acc = results['true_acc']
|
||||
estim_acc = np.asarray(results['estim_acc'])
|
||||
if any(np.isnan(estim_acc)):
|
||||
print(f'nan values found in {method=} {dataset=}')
|
||||
continue
|
||||
if any(estim_acc>1.00001):
|
||||
print(f'values >1 found in {method=} {dataset=} [max={estim_acc.max()}]')
|
||||
continue
|
||||
if any(estim_acc<-0.00001):
|
||||
print(f'values <0 found in {method=} {dataset=} [min={estim_acc.min()}]')
|
||||
continue
|
||||
errors = cap_errors(true_acc, estim_acc)
|
||||
table.add(dataset, method, errors)
|
||||
|
||||
tex = table.latexTabular()
|
||||
table_name = f'{basedir}_{classifier}_{metric}.tex'
|
||||
with open(f'./tables/{table_name}', 'wt') as foo:
|
||||
foo.write('\\resizebox{\\textwidth}{!}{%\n')
|
||||
foo.write('\\begin{tabular}{c|'+('c'*len(methods))+'}\n')
|
||||
foo.write(tex)
|
||||
foo.write('\\end{tabular}%\n')
|
||||
foo.write('}\n')
|
||||
tex = table.latexTabular()
|
||||
table_name = f'{basedir}_{classifier}_{metric}.tex'
|
||||
with open(f'./tables/{table_name}', 'wt') as foo:
|
||||
foo.write('\\begin{table}[h]\n')
|
||||
foo.write('\\centering\n')
|
||||
foo.write('\\resizebox{\\textwidth}{!}{%\n')
|
||||
foo.write('\\begin{tabular}{c|'+('c'*len(methods))+'}\n')
|
||||
foo.write(tex)
|
||||
foo.write('\\end{tabular}%\n')
|
||||
foo.write('}\n')
|
||||
foo.write('\\caption{Classifier ' + classifier.replace('_', ' ') + ('(oracle)' if with_oracle else '') +
|
||||
' evaluated in terms of ' + metric.replace('_', ' ') + '}\n')
|
||||
foo.write('\\end{table}\n')
|
||||
|
||||
tex_doc += "\input{" + table_name + "}\n\n"
|
||||
tex_doc += "\input{" + table_name + "}\n\n"
|
||||
|
||||
tex_doc += """
|
||||
\\end{document}
|
||||
|
@ -368,3 +381,4 @@ def gen_tables(basedir, datasets):
|
|||
os.system('pdflatex main.tex')
|
||||
os.system('rm main.aux main.log')
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue