diff --git a/Retrieval/commons.py b/Retrieval/commons.py index d9077e4..79bf029 100644 --- a/Retrieval/commons.py +++ b/Retrieval/commons.py @@ -63,7 +63,7 @@ class TextRankings: O = self.obj docs_ids = [doc_id for doc_id, query_id in O['qid'].items() if query_id == sample_id] texts = [O['text'][doc_id] for doc_id in docs_ids] - labels = [O['continent'][doc_id] for doc_id in docs_ids] + labels = [O[self.class_name][doc_id] for doc_id in docs_ids] if max_lines > 0 and len(texts) > max_lines: ranks = [int(O['rank'][doc_id]) for doc_id in docs_ids] sel = np.argsort(ranks)[:max_lines] diff --git a/Retrieval/fifth.py b/Retrieval/fifth.py index 790ffca..af06b2b 100644 --- a/Retrieval/fifth.py +++ b/Retrieval/fifth.py @@ -104,26 +104,33 @@ RANK_AT_K = -1 REDUCE_TR = 50000 qp.environ['SAMPLE_SIZE'] = RANK_AT_K -data_path = './newExperimentalSetup' -train_path = join(data_path, 'train3000samples.json') +data_path = { + 'first_letter_category': './first_letter_categoryDataset', + 'continent': './newExperimentalSetup' +} + +def scape_latex(string): + return string.replace('_', '\_') Ks = [10, 50, 100, 250, 500, 1000, 2000] # Ks = [500] -for CLASS_NAME in ['continent']: #, 'gender', 'gender_category', 'occupations', 'source_countries', 'source_subcont_regions', 'years_category', 'relative_pageviews_category']: +for CLASS_NAME in ['first_letter_category']: #['continent']: #, 'gender', 'gender_category', 'occupations', 'source_countries', 'source_subcont_regions', 'years_category', 'relative_pageviews_category']: + + train_path = join(data_path[CLASS_NAME], 'train3000samples.json') tfidf, classifier_trained = qp.util.pickled_resource(f'classifier_{CLASS_NAME}.pkl', train_classifier) trained=True - experiment_prot = RetrievedSamples(data_path, + experiment_prot = RetrievedSamples(data_path[CLASS_NAME], load_fn=load_json_sample, vectorizer=tfidf, max_train_lines=None, max_test_lines=RANK_AT_K, classes=classifier_trained.classes_, class_name=CLASS_NAME) method_names = [name for name, *other in methods()] - benchmarks = [f'{CLASS_NAME}@{k}' for k in Ks] + benchmarks = [f'{scape_latex(CLASS_NAME)}@{k}' for k in Ks] table_mae = Table(benchmarks, method_names, color_mode='global') table_mrae = Table(benchmarks, method_names, color_mode='global') @@ -158,8 +165,9 @@ for CLASS_NAME in ['continent']: #, 'gender', 'gender_category', 'occupations', pbar.set_description(f'{method_name}') for k in Ks: - table_mae.add(benchmark=f'{CLASS_NAME}@{k}', method=method_name, values=mae_errors[k]) - table_mrae.add(benchmark=f'{CLASS_NAME}@{k}', method=method_name, values=mrae_errors[k]) + + table_mae.add(benchmark=f'{scape_latex(CLASS_NAME)}@{k}', method=method_name, values=mae_errors[k]) + table_mrae.add(benchmark=f'{scape_latex(CLASS_NAME)}@{k}', method=method_name, values=mrae_errors[k]) table_mae.latexPDF('./latex', 'table_mae.tex') table_mrae.latexPDF('./latex', 'table_mrae.tex')