From 323749178b6949890c9b2b9cab6fc6a6da35e2e0 Mon Sep 17 00:00:00 2001 From: Alejandro Moreo Date: Wed, 14 Feb 2024 18:50:53 +0100 Subject: [PATCH] adding some preliminary baselines... --- LeQua2024/_lequa2024.py | 83 +++++++++++++++++++++++++ LeQua2024/baselines.py | 123 +++++++++++++++++++++++++++++++++++++ LeQua2024/predict.py | 58 +++++++++++++++++ LeQua2024/run_baselines.sh | 35 +++++++++++ 4 files changed, 299 insertions(+) create mode 100644 LeQua2024/_lequa2024.py create mode 100644 LeQua2024/baselines.py create mode 100644 LeQua2024/predict.py create mode 100755 LeQua2024/run_baselines.sh diff --git a/LeQua2024/_lequa2024.py b/LeQua2024/_lequa2024.py new file mode 100644 index 0000000..5e414c5 --- /dev/null +++ b/LeQua2024/_lequa2024.py @@ -0,0 +1,83 @@ +from typing import Tuple, Union +import pandas as pd +import numpy as np +import os +from os.path import join + +from scripts.data import load_vector_documents + +from quapy.data import LabelledCollection +from quapy.protocol import AbstractProtocol + + +LEQUA2024_TASKS = ['T1', 'T2', 'T3', 'T4'] + + +class LabelledCollectionsFromDir(AbstractProtocol): + + def __init__(self, path_dir:str, ground_truth_path:str, load_fn): + self.path_dir = path_dir + self.load_fn = load_fn + self.true_prevs = pd.read_csv(ground_truth_path, index_col=0) + + def __call__(self): + for id, prevalence in self.true_prevs.iterrows(): + collection_path = os.path.join(self.path_dir, f'{id}.txt') + lc = LabelledCollection.load(path=collection_path, loader_func=self.load_fn) + yield lc + + +def fetch_lequa2024(task, data_home='./data', merge_T3=False): + + from quapy.data._lequa2022 import SamplesFromDir + + assert task in LEQUA2024_TASKS, \ + f'Unknown task {task}. Valid ones are {LEQUA2024_TASKS}' + + # if data_home is None: + # data_home = get_quapy_home() + lequa_dir = data_home + + # URL_TRAINDEV=f'https://zenodo.org/record/6546188/files/{task}.train_dev.zip' + # URL_TEST=f'https://zenodo.org/record/6546188/files/{task}.test.zip' + # URL_TEST_PREV=f'https://zenodo.org/record/6546188/files/{task}.test_prevalences.zip' + + # lequa_dir = join(data_home, 'lequa2024') + # os.makedirs(lequa_dir, exist_ok=True) + + # def download_unzip_and_remove(unzipped_path, url): + # tmp_path = join(lequa_dir, task + '_tmp.zip') + # download_file_if_not_exists(url, tmp_path) + # with zipfile.ZipFile(tmp_path) as file: + # file.extractall(unzipped_path) + # os.remove(tmp_path) + + # if not os.path.exists(join(lequa_dir, task)): + # download_unzip_and_remove(lequa_dir, URL_TRAINDEV) + # download_unzip_and_remove(lequa_dir, URL_TEST) + # download_unzip_and_remove(lequa_dir, URL_TEST_PREV) + + load_fn = load_vector_documents + + val_samples_path = join(lequa_dir, task, 'public', 'dev_samples') + val_true_prev_path = join(lequa_dir, task, 'public', 'dev_prevalences.txt') + val_gen = SamplesFromDir(val_samples_path, val_true_prev_path, load_fn=load_fn) + + test_samples_path = join(lequa_dir, task, 'public', 'test_samples') + test_true_prev_path = join(lequa_dir, task, 'public', 'test_prevalences.txt') + test_gen = SamplesFromDir(test_samples_path, test_true_prev_path, load_fn=load_fn) + + if task != 'T3': + tr_path = join(lequa_dir, task, 'public', 'training_data.txt') + train = LabelledCollection.load(tr_path, loader_func=load_fn) + return train, val_gen, test_gen + else: + training_samples_path = join(lequa_dir, task, 'public', 'training_samples') + training_true_prev_path = join(lequa_dir, task, 'public', 'training_prevalences.txt') + train_gen = LabelledCollectionsFromDir(training_samples_path, training_true_prev_path, load_fn=load_fn) + if merge_T3: + train = LabelledCollection.join(*list(train_gen())) + return train, val_gen, test_gen + else: + return train_gen, val_gen, test_gen + diff --git a/LeQua2024/baselines.py b/LeQua2024/baselines.py new file mode 100644 index 0000000..c285199 --- /dev/null +++ b/LeQua2024/baselines.py @@ -0,0 +1,123 @@ +import argparse +import pickle +import os +from os.path import join +from sklearn.linear_model import LogisticRegression as LR + +from scripts.constants import SAMPLE_SIZE +from LeQua2024._lequa2024 import LEQUA2024_TASKS, fetch_lequa2024 +from quapy.method.aggregative import * +from quapy.method.non_aggregative import MaximumLikelihoodPrevalenceEstimation as MLPE +import quapy.functional as F + + +# LeQua official baselines +# ================================================================================= + +BINARY_TASKS = ['T1', 'T4'] + + +def new_cls(): + return LR(n_jobs=-1) + + +lr_params = { + 'C': np.logspace(-3, 3, 7), + 'class_weight': [None, 'balanced'] +} + +def wrap_params(cls_params:dict, prefix:str): + return {'__'.join([prefix, key]): val for key, val in cls_params.items()} + + + +def baselines(): + + q_params = wrap_params(lr_params, 'classifier') + + # yield CC(new_cls()), "CC", q_params + # yield ACC(new_cls()), "ACC", q_params + # yield PCC(new_cls()), "PCC", q_params + # yield PACC(new_cls()), "PACC", q_params + # yield EMQ(CalibratedClassifierCV(new_cls())), "SLD-Platt", wrap_params(wrap_params(lr_params, 'estimator'), 'classifier') + # yield EMQ(new_cls()), "SLD", q_params + # yield EMQ(new_cls()), "SLD-BCTS", {**q_params, 'recalib': ['bcts'], 'val_split': [5]} + yield MLPE(), "MLPE", None + # if args.task in BINARY_TASKS: + # yield MS2(new_cls()), "MedianSweep2", q_params + # yield KDEyML(new_cls()), "KDEy-ML" + # yield MLPE(), "MLPE" + + +def main(args): + + models_path = qp.util.create_if_not_exist(join('./models', args.task)) + + qp.environ['SAMPLE_SIZE'] = SAMPLE_SIZE[args.task] + + train, gen_val, gen_test = fetch_lequa2024(task=args.task, data_home=args.datadir, merge_T3=True) + + print(f'number of classes: {len(train.classes_)}') + print(f'number of training documents: {len(train)}') + print(f'training prevalence: {F.strprev(train.prevalence())}') + print(f'training matrix shape: {train.instances.shape}') + + for quantifier, q_name, param_grid in baselines(): + + model_path = os.path.join(models_path, q_name + '.pkl') + if os.path.exists(model_path): + print(f'a pickle for {q_name} exists already in {model_path}; skipping!') + continue + + if param_grid is not None: + quantifier = qp.model_selection.GridSearchQ( + quantifier, + param_grid, + protocol=gen_val, + error=qp.error.mrae, + refit=False, + verbose=True, + n_jobs=-1 + ).fit(train) + print(f'{q_name} got MRAE={quantifier.best_score_:.5f} (hyper-params: {quantifier.best_params_})') + quantifier = quantifier.best_model() + else: + quantifier.fit(train) + + + # valid_error = quantifier.best_score_ + + # test_err = qp.evaluation.evaluate(quantifier, protocol=gen_test, error_metric='mrae', verbose=True) + # print(f'method={q_name} got MRAE={test_err:.4f}') + # + # results.append((q_name, valid_error, test_err)) + + + print(f'saving model in {model_path}') + pickle.dump(quantifier, open(model_path, 'wb'), protocol=pickle.HIGHEST_PROTOCOL) + + + # print('\nResults') + # print('Method\tValid-err\ttest-err') + # for q_name, valid_error, test_err in results: + # print(f'{q_name}\t{valid_error:.4}\t{test_err:.4f}') + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser(description='LeQua2024 baselines') + parser.add_argument('task', metavar='TASK', type=str, choices=LEQUA2024_TASKS, + help=f'Code of the task; available ones are {LEQUA2024_TASKS}') + parser.add_argument('datadir', metavar='DATA-PATH', type=str, + help='Path of the directory containing LeQua 2024 data', default='./data') + args = parser.parse_args() + + # def assert_file(filename): + # if not os.path.exists(os.path.join(args.datadir, filename)): + # raise FileNotFoundError(f'path {args.datadir} does not contain "{filename}"') + # + # assert_file('dev_prevalences.txt') + # assert_file('training_data.txt') + # assert_file('dev_samples') + + main(args) diff --git a/LeQua2024/predict.py b/LeQua2024/predict.py new file mode 100644 index 0000000..ee6fc0a --- /dev/null +++ b/LeQua2024/predict.py @@ -0,0 +1,58 @@ +import argparse +import quapy as qp +from scripts.data import ResultSubmission +import os +import pickle +from tqdm import tqdm +from scripts.data import gen_load_samples +from glob import glob +from scripts import constants + +""" +LeQua2024 prediction script +""" + +def main(args): + + if not args.force and os.path.exists(args.output): + print(f'prediction file {args.output} already exists! set --force to override') + return + + # check the number of samples + nsamples = len(glob(os.path.join(args.samples, f'*.txt'))) + if nsamples not in {constants.DEV_SAMPLES, constants.TEST_SAMPLES}: + print(f'Warning: The number of samples (.txt) in {args.samples} does neither coincide with the expected number of ' + f'dev samples ({constants.DEV_SAMPLES}) nor with the expected number of ' + f'test samples ({constants.TEST_SAMPLES}).') + + # load pickled model + model = pickle.load(open(args.model, 'rb')) + + # predictions + predictions = ResultSubmission() + for sampleid, sample in tqdm(gen_load_samples(args.samples, return_id=True), desc='predicting', total=nsamples): + predictions.add(sampleid, model.quantify(sample)) + + # saving + qp.util.create_parent_dir(args.output) + predictions.dump(args.output) + + +if __name__=='__main__': + parser = argparse.ArgumentParser(description='LeQua2022 prediction script') + parser.add_argument('model', metavar='MODEL-PATH', type=str, + help='Path of saved model') + parser.add_argument('samples', metavar='SAMPLES-PATH', type=str, + help='Path to the directory containing the samples') + parser.add_argument('output', metavar='PREDICTIONS-PATH', type=str, + help='Path where to store the predictions file') + parser.add_argument('--force', action='store_true', + help='Overrides prediction file if exists') + args = parser.parse_args() + + if not os.path.exists(args.samples): + raise FileNotFoundError(f'path {args.samples} does not exist') + if not os.path.isdir(args.samples): + raise ValueError(f'path {args.samples} is not a valid directory') + + main(args) diff --git a/LeQua2024/run_baselines.sh b/LeQua2024/run_baselines.sh new file mode 100755 index 0000000..d4fc761 --- /dev/null +++ b/LeQua2024/run_baselines.sh @@ -0,0 +1,35 @@ +#!/bin/bash +set -x + + +# T1: binary (n=2) +# T2: multiclass (n=28) +# T3: ordinal (n=5) +# T4: covariante shift (n=2) + +# -------------------------------------------------------------------------------- +# DEV +# -------------------------------------------------------------------------------- + +mkdir results + +for task in T1 T2 T3 T4 ; do + + echo "" > results/$task.txt + + PYTHONPATH=.:scripts/:.. python3 baselines.py $task data/ + + SAMPLES=data/$task/public/dev_samples + TRUEPREVS=data/$task/public/dev_prevalences.txt + + for pickledmodel in models/$task/*.pkl ; do + model=$(basename "$pickledmodel" .pkl) + + PREDICTIONS=predictions/$task/$model.txt + + PYTHONPATH=.:scripts/:.. python3 predict.py models/$task/$model.pkl $SAMPLES $PREDICTIONS + PYTHONPATH=.:scripts/:.. python3 scripts/evaluate.py $task $TRUEPREVS $PREDICTIONS >> results/$task.txt + done + +done +