From 8e15678c36398671678d5e107ce77601c0105f88 Mon Sep 17 00:00:00 2001
From: Alex Moreo
Date: Wed, 24 Nov 2021 11:21:40 +0100
Subject: [PATCH] adding baselines t2 from simple tfidf vectorization

---
 LeQua2022/baselines_T2.py | 105 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 105 insertions(+)
 create mode 100644 LeQua2022/baselines_T2.py

diff --git a/LeQua2022/baselines_T2.py b/LeQua2022/baselines_T2.py
new file mode 100644
index 0000000..ca70b03
--- /dev/null
+++ b/LeQua2022/baselines_T2.py
@@ -0,0 +1,105 @@
+import argparse
+import pickle
+
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.linear_model import LogisticRegression as LR
+from quapy.method.aggregative import *
+from quapy.method.non_aggregative import MaximumLikelihoodPrevalenceEstimation as MLPE
+import quapy.functional as F
+from data import *
+import os
+import constants
+
+
+# LeQua official baselines for task T2A (Binary/Raw documents) and T2B (Multiclass/Raw documents)
+# =========================================================
+
+def baselines():
+    yield CC(LR(n_jobs=-1)), "CC"
+    yield ACC(LR(n_jobs=-1)), "ACC"
+    yield PCC(LR(n_jobs=-1)), "PCC"
+    yield PACC(LR(n_jobs=-1)), "PACC"
+    yield EMQ(CalibratedClassifierCV(LR(), n_jobs=-1)), "SLD"
+    yield HDy(LR(n_jobs=-1)) if args.task == 'T2A' else OneVsAll(HDy(LR()), n_jobs=-1), "HDy"
+    # yield MLPE(), "MLPE"
+
+
+def main(args):
+
+    models_path = qp.util.create_if_not_exist(os.path.join(args.modeldir, args.task))
+
+    path_dev_documents = os.path.join(args.datadir, 'dev_documents')
+    path_dev_prevs = os.path.join(args.datadir, 'dev_prevalences.csv')
+    path_train = os.path.join(args.datadir, 'training_documents.txt')
+
+    qp.environ['SAMPLE_SIZE'] = constants.SAMPLE_SIZE[args.task]
+
+    train = LabelledCollection.load(path_train, load_raw_documents)
+    tfidf = TfidfVectorizer(lowercase=True, stop_words='english', min_df=4)  # TfidfVectorizer(min_df=5)
+    train.instances = tfidf.fit_transform(train.instances)
+    nF = train.instances.shape[1]
+
+    print(f'number of classes: {len(train.classes_)}')
+    print(f'number of training documents: {len(train)}')
+    print(f'training prevalence: {F.strprev(train.prevalence())}')
+    print(f'training matrix shape: {train.instances.shape}')
+
+    param_grid = {
+        'C': np.logspace(-3, 3, 7),
+        'class_weight': ['balanced', None]
+    }
+
+    # param_grid = {
+    #     'C': [1],
+    #     'class_weight': ['balanced']
+    # }
+
+    def gen_samples():
+        return gen_load_samples(path_dev_documents, ground_truth_path=path_dev_prevs, return_id=False,
+                                load_fn=load_raw_unlabelled_documents, vectorizer=tfidf)
+
+    for quantifier, q_name in baselines():
+        print(f'{q_name}: Model selection')
+        quantifier = qp.model_selection.GridSearchQ(
+            quantifier,
+            param_grid,
+            sample_size=None,
+            protocol='gen',
+            error=qp.error.mae,
+            refit=False,
+            verbose=True
+        ).fit(train, gen_samples)
+
+        print(f'{q_name} got MAE={quantifier.best_score_:.3f} (hyper-params: {quantifier.best_params_})')
+
+        model_path = os.path.join(models_path, q_name+'.'+args.task+'.pkl')
+        print(f'saving model in {model_path}')
+        pickle.dump(quantifier.best_model(), open(model_path, 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='LeQua2022 Task T2A/T2B baselines')
+    parser.add_argument('task', metavar='TASK', type=str, choices=['T2A', 'T2B'],
+                        help='Task name (T2A, T2B)')
+    parser.add_argument('datadir', metavar='DATA-PATH', type=str,
+                        help='Path of the directory containing "dev_prevalences.csv", "training_documents.txt", and ' +
+                             'the directory "dev_documents"')
+    parser.add_argument('modeldir', metavar='MODEL-PATH', type=str,
+                        help='Path where to save the models. '
+                             'A subdirectory named after the task will be automatically created.')
+    args = parser.parse_args()
+
+    if not os.path.exists(args.datadir):
+        raise FileNotFoundError(f'path {args.datadir} does not exist')
+    if not os.path.isdir(args.datadir):
+        raise ValueError(f'path {args.datadir} is not a valid directory')
+    if not os.path.exists(os.path.join(args.datadir, "dev_prevalences.csv")):
+        raise FileNotFoundError(f'path {args.datadir} does not contain "dev_prevalences.csv" file')
+    if not os.path.exists(os.path.join(args.datadir, "training_documents.txt")):
+        raise FileNotFoundError(f'path {args.datadir} does not contain "training_documents.txt" file')
+    if not os.path.exists(os.path.join(args.datadir, "dev_documents")):
+        raise FileNotFoundError(f'path {args.datadir} does not contain "dev_documents" folder')
+
+    main(args)
+
+    # print('WITHOUT MODEL SELECTION')
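
Usage sketch (reviewer note, not part of the patch). The script pickles only the fitted quantifier, not the TfidfVectorizer, so applying a saved baseline to raw development or test documents requires re-creating the same TF-IDF preprocessing. The snippet below is a minimal sketch of that step under a few assumptions: it reuses the load_raw_documents helper invoked above (assumed to live in the repo's data.py), treats 'data/T2A', 'models', the task 'T2A' and the sample file '0.txt' as purely illustrative paths, and assumes each file under dev_documents holds one raw document per line.

    import os
    import pickle

    from sklearn.feature_extraction.text import TfidfVectorizer

    import quapy.functional as F
    from quapy.data import LabelledCollection
    from data import load_raw_documents  # repo helper, as used in baselines_T2.py

    datadir = 'data/T2A'   # hypothetical paths
    modeldir = 'models'
    task = 'T2A'

    # re-create the training-time preprocessing (the patch pickles only the quantifier)
    train = LabelledCollection.load(os.path.join(datadir, 'training_documents.txt'), load_raw_documents)
    tfidf = TfidfVectorizer(lowercase=True, stop_words='english', min_df=4)
    tfidf.fit(train.instances)

    # load one of the pickled baselines, e.g. the PACC model saved by the script
    with open(os.path.join(modeldir, task, f'PACC.{task}.pkl'), 'rb') as fin:
        quantifier = pickle.load(fin)

    # read one raw sample (illustrative file name; one document per line assumed)
    with open(os.path.join(datadir, 'dev_documents', '0.txt'), 'rt', encoding='utf-8') as fin:
        documents = [line.strip() for line in fin]

    # estimate the class prevalence values for this sample
    estim_prev = quantifier.quantify(tfidf.transform(documents))
    print(f'estimated prevalence: {F.strprev(estim_prev)}')

Persisting the fitted vectorizer next to each model (for instance, pickling a (tfidf, quantifier) pair) would avoid re-fitting it at prediction time.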