final plots
This commit is contained in:
parent
366020d45c
commit
5284e04c90
|
@ -54,7 +54,7 @@ To evaluate our approach, I have executed the queries on the test split. You can
|
|||
"""
|
||||
|
||||
|
||||
def methods(classifier, class_name, binarize=False):
|
||||
def methods(classifier, class_name=None, binarize=False):
|
||||
|
||||
kde_param = {
|
||||
'continent': 0.01,
|
||||
|
@ -75,7 +75,7 @@ def methods(classifier, class_name, binarize=False):
|
|||
# yield ('EMQ-TS', EMQ(classifier, exact_train_prev=False, recalib='ts'))
|
||||
# yield ('EMQ-NBVS', EMQ(classifier, exact_train_prev=False, recalib='nbvs'))
|
||||
# yield ('EMQ-VS', EMQ(classifier, exact_train_prev=False, recalib='vs'))
|
||||
yield ('KDEy-ML', KDEyML(classifier, val_split=5, n_jobs=-1, bandwidth=kde_param[class_name]))
|
||||
yield ('KDEy-ML', KDEyML(classifier, val_split=5, n_jobs=-1, bandwidth=kde_param.get(class_name, 0.01)))
|
||||
# yield ('KDE01', KDEyML(classifier, val_split=5, n_jobs=-1, bandwidth=0.01))
|
||||
if binarize:
|
||||
yield ('M3b', M3rND_ModelB(classifier))
|
||||
|
@ -135,9 +135,12 @@ def reduceAtK(data: LabelledCollection, k):
|
|||
return LabelledCollection(X, y, classes=data.classes_)
|
||||
|
||||
|
||||
def benchmark_name(class_name, k):
|
||||
def benchmark_name(class_name, k=None):
|
||||
scape_class_name = class_name.replace('_', '\_')
|
||||
return f'{scape_class_name}@{k}'
|
||||
if k is None:
|
||||
return scape_class_name
|
||||
else:
|
||||
return f'{scape_class_name}@{k}'
|
||||
|
||||
|
||||
def run_experiment():
|
||||
|
@ -154,8 +157,7 @@ def run_experiment():
|
|||
Xtr, ytr, score_tr = train
|
||||
Xte, yte, score_te = test
|
||||
|
||||
n = len(ytr) // 2
|
||||
train_col = LabelledCollection(Xtr[:n], ytr[:n], classes=classifier.classes_)
|
||||
train_col = LabelledCollection(Xtr, ytr, classes=classifier.classes_)
|
||||
|
||||
if method_name not in ['Naive', 'NaiveQuery', 'M3b', 'M3b+', 'M3d', 'M3d+']:
|
||||
method.fit(train_col, val_split=train_col, fit_classifier=False)
|
||||
|
@ -214,8 +216,10 @@ def run_experiment():
|
|||
|
||||
if isinstance(method, AbstractM3rND):
|
||||
if method_name.endswith('+'):
|
||||
# learns the correction parameters from the query-specific training data
|
||||
conf_matrix_ = method.get_confusion_matrix(*train_col.Xy)
|
||||
else:
|
||||
# learns the correction parameters from the training data used to train the classifier
|
||||
conf_matrix_ = conf_matrix.copy()
|
||||
rND_estim = method.fair_measure_correction(rND_estim, conf_matrix_)
|
||||
|
||||
|
@ -241,11 +245,17 @@ protected_group = {
|
|||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# final tables only contain the information for the data size 10K, each row is a class name and each colum
|
||||
# the corresponding rND (for binary) or rKL (for multiclass) score
|
||||
tables_RND, tables_DKL = [], []
|
||||
tables_final = []
|
||||
for class_mode in ['binary', 'multiclass']:
|
||||
BINARIZE = (class_mode=='binary')
|
||||
method_names = [name for name, *other in methods(None, 'continent', BINARIZE)]
|
||||
|
||||
method_names = [name for name, *other in methods(None, binarize=BINARIZE)]
|
||||
|
||||
table_final = Table(name=f'rND' if BINARIZE else f'rKL', benchmarks=[benchmark_name(c) for c in CLASS_NAMES], methods=method_names)
|
||||
table_final.format.mean_macro = False
|
||||
tables_final.append(table_final)
|
||||
for class_name in CLASS_NAMES:
|
||||
tables_mae, tables_mrae = [], []
|
||||
|
||||
|
@ -298,6 +308,10 @@ if __name__ == '__main__':
|
|||
if BINARIZE:
|
||||
table_RND.add(benchmark=benchmark_name(class_name, data_size), method=method_name, v=results['rND_error'])
|
||||
|
||||
if data_size=='10K':
|
||||
value = results['rND_error'] if BINARIZE else results['rKL_error']
|
||||
table_final.add(benchmark=benchmark_name(class_name), method=method_name, v=value)
|
||||
|
||||
tables = ([table_RND] + tables_mrae) if BINARIZE else ([table_DKL] + tables_mrae)
|
||||
Table.LatexPDF(f'./latex/{class_mode}/{class_name}.pdf', tables=tables)
|
||||
|
||||
|
@ -307,6 +321,7 @@ if __name__ == '__main__':
|
|||
tables_DKL.append(table_DKL)
|
||||
|
||||
Table.LatexPDF(f'./latex/global/main.pdf', tables=tables_RND+tables_DKL, dedicated_pages=False)
|
||||
Table.LatexPDF(f'./latex/final/main.pdf', tables=tables_final, dedicated_pages=False)
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -78,6 +78,8 @@ results = load_all_results()
|
|||
for class_name in CLASS_NAME:
|
||||
for data_size in DATA_SIZE:
|
||||
|
||||
log = True
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
max_means = []
|
||||
|
@ -99,10 +101,12 @@ for class_name in CLASS_NAME:
|
|||
|
||||
line = ax.plot(Ks, means, 'o-', label=method_name, color=None)
|
||||
color = line[-1].get_color()
|
||||
if log:
|
||||
ax.set_yscale('log')
|
||||
# ax.fill_between(Ks, means - stds, means + stds, alpha=0.3, color=color)
|
||||
|
||||
ax.set_xlabel('k')
|
||||
ax.set_ylabel('RAE')
|
||||
ax.set_ylabel('RAE' + ('(log scale)' if log else ''))
|
||||
ax.set_title(f'{class_name} from {data_size}')
|
||||
ax.set_ylim([0, max(max_means)*1.05])
|
||||
|
||||
|
|
|
@ -28,9 +28,10 @@ import matplotlib.pyplot as plt
|
|||
data_home = 'data'
|
||||
class_mode = 'multiclass'
|
||||
|
||||
method_names = [name for name, *other in methods(None, 'continent')]
|
||||
method_names = [name for name, *other in methods(None)]
|
||||
|
||||
Ks = [5, 10, 25, 50, 75, 100, 250, 500, 750, 1000]
|
||||
# Ks = [5, 10, 25, 50, 75, 100, 250, 500, 750, 1000]
|
||||
Ks = [50, 100, 500, 1000]
|
||||
DATA_SIZE = ['10K', '50K', '100K', '500K', '1M', 'FULL']
|
||||
CLASS_NAME = ['gender', 'continent', 'years_category']
|
||||
all_results = {}
|
||||
|
@ -46,6 +47,8 @@ results = load_all_results()
|
|||
for class_name in CLASS_NAME:
|
||||
for k in Ks:
|
||||
|
||||
log = True
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
max_means = []
|
||||
|
@ -60,15 +63,18 @@ for class_name in CLASS_NAME:
|
|||
# max_mean = np.max([
|
||||
# results[class_name][data_size][method_name][k]['max'] for data_size in DATA_SIZE
|
||||
# ])
|
||||
|
||||
max_means.append(max(means))
|
||||
|
||||
style = 'o-' if method_name != 'CC' else '--'
|
||||
line = ax.plot(DATA_SIZE, means, style, label=method_name, color=None)
|
||||
color = line[-1].get_color()
|
||||
if log:
|
||||
ax.set_yscale('log')
|
||||
# ax.fill_between(Ks, means - stds, means + stds, alpha=0.3, color=color)
|
||||
|
||||
ax.set_xlabel('training pool size')
|
||||
ax.set_ylabel('RAE')
|
||||
ax.set_ylabel('RAE' + ('(log scale)' if log else ''))
|
||||
ax.set_title(f'{class_name} from {k=}')
|
||||
ax.set_ylim([0, max(max_means)*1.05])
|
||||
|
||||
|
|
Loading…
Reference in New Issue