import argparse
import quapy as qp
from scripts.data import ResultSubmission
import os
import pickle
from tqdm import tqdm
from scripts.data import gen_load_samples
from glob import glob
from scripts import constants
from regressor import KDEyRegressor, RegressionToSimplex

"""
LeQua2024 prediction script 
"""

def main(args):
    """Quantify every sample in a directory with a pickled model and save the results.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``model`` (path of the pickled model), ``samples`` (directory
            holding the ``.txt`` samples), ``output`` (path of the predictions
            file) and ``force`` (when true, an existing predictions file is
            overwritten instead of aborting).

    Returns:
        None. If ``args.output`` already exists and ``args.force`` is not set,
        a message is printed and nothing else is done.
    """
    if not args.force and os.path.exists(args.output):
        print(f'prediction file {args.output} already exists! set --force to override')
        return

    # check the number of samples; a mismatch only triggers a warning, the
    # prediction loop below still runs on whatever samples are found
    nsamples = len(glob(os.path.join(args.samples, '*.txt')))
    if nsamples not in {constants.DEV_SAMPLES, constants.TEST_SAMPLES}:
        print(f'Warning: The number of samples (.txt) in {args.samples} '
              f'does neither coincide with the expected number of '
              f'dev samples ({constants.DEV_SAMPLES}) nor with the expected number of '
              f'test samples ({constants.TEST_SAMPLES}).')

    # load pickled model; the context manager guarantees the file handle is
    # closed even if unpickling fails.
    # NOTE: unpickling can execute arbitrary code, so only load model files
    # that come from a trusted source.
    with open(args.model, 'rb') as model_file:
        model = pickle.load(model_file)

    # predictions: one entry per sample id, holding the model's quantification
    predictions = ResultSubmission()
    for sampleid, sample in tqdm(gen_load_samples(args.samples, return_id=True), desc='predicting', total=nsamples):
        predictions.add(sampleid, model.quantify(sample))

    # saving (the parent directory of the output path is prepared first)
    qp.util.create_parent_dir(args.output)
    predictions.dump(args.output)


if __name__ == '__main__':
    # Command-line interface: three positional paths plus an overwrite switch.
    cli_parser = argparse.ArgumentParser(description='LeQua2024 prediction script')
    cli_parser.add_argument(
        'model', metavar='MODEL-PATH', type=str, help='Path of saved model')
    cli_parser.add_argument(
        'samples', metavar='SAMPLES-PATH', type=str,
        help='Path to the directory containing the samples')
    cli_parser.add_argument(
        'output', metavar='PREDICTIONS-PATH', type=str,
        help='Path where to store the predictions file')
    cli_parser.add_argument(
        '--force', action='store_true', help='Overrides prediction file if exists')
    args = cli_parser.parse_args()

    # Validate the samples location before doing any work: a missing path and
    # a path that exists but is not a directory are reported differently.
    if not os.path.isdir(args.samples):
        if not os.path.exists(args.samples):
            raise FileNotFoundError(f'path {args.samples} does not exist')
        raise ValueError(f'path {args.samples} is not a valid directory')

    main(args)