import zipfile

import pandas as pd
import os
from os.path import join

import quapy as qp
from scripts.data import load_vector_documents

from quapy.data import LabelledCollection
from quapy.protocol import AbstractProtocol
from quapy.util import download_file_if_not_exists

LEQUA2024_TASKS = ['T1', 'T2', 'T3', 'T4']

LEQUA2024_ZENODO = 'https://zenodo.org/record/11091067'  # v2, no ground truth for test yet


class LabelledCollectionsFromDir(AbstractProtocol):

    def __init__(self, path_dir:str, ground_truth_path:str, load_fn):
        self.path_dir = path_dir
        self.load_fn = load_fn
        self.true_prevs = pd.read_csv(ground_truth_path, index_col=0)

    def __call__(self):
        for id, prevalence in self.true_prevs.iterrows():
            collection_path = os.path.join(self.path_dir, f'{id}.txt')
            lc = LabelledCollection.load(path=collection_path, loader_func=self.load_fn)
            yield lc


def fetch_lequa2024(task, data_home=None, merge_T3=False):

    from quapy.data._lequa2022 import SamplesFromDir

    assert task in LEQUA2024_TASKS, \
        f'Unknown task {task}. Valid ones are {LEQUA2024_TASKS}'

    if data_home is None:
        data_home = qp.util.get_quapy_home()

    lequa_dir = data_home

    URL_TRAINDEV=f'{LEQUA2024_ZENODO}/files/{task}.train_dev.zip'
    URL_TEST=f'{LEQUA2024_ZENODO}/files/{task}.test.zip'
    # URL_TEST_PREV=f'{LEQUA2024_ZENODO}/files/{task}.test_prevalences.zip'

    lequa_dir = join(data_home, 'lequa2024')
    os.makedirs(lequa_dir, exist_ok=True)

    def download_unzip_and_remove(unzipped_path, url):
        tmp_path = join(lequa_dir, task + '_tmp.zip')
        download_file_if_not_exists(url, tmp_path)
        with zipfile.ZipFile(tmp_path) as file:
            file.extractall(unzipped_path)
        os.remove(tmp_path)

    if not os.path.exists(join(lequa_dir, task)):
        download_unzip_and_remove(lequa_dir, URL_TRAINDEV)
        download_unzip_and_remove(lequa_dir, URL_TEST)
    #     download_unzip_and_remove(lequa_dir, URL_TEST_PREV)

    load_fn = load_vector_documents

    val_samples_path = join(lequa_dir, task, 'public', 'dev_samples')
    val_true_prev_path = join(lequa_dir, task, 'public', 'dev_prevalences.txt')
    val_gen = SamplesFromDir(val_samples_path, val_true_prev_path, load_fn=load_fn)

    # test_samples_path = join(lequa_dir, task, 'public', 'test_samples')
    # test_true_prev_path = join(lequa_dir, task, 'public', 'test_prevalences.txt')
    # test_gen = SamplesFromDir(test_samples_path, test_true_prev_path, load_fn=load_fn)
    test_gen = None

    if task != 'T3':
        tr_path = join(lequa_dir, task, 'public', 'training_data.txt')
        train = LabelledCollection.load(tr_path, loader_func=load_fn)
        return train, val_gen, test_gen
    else:
        training_samples_path = join(lequa_dir, task, 'public', 'training_samples')
        training_true_prev_path = join(lequa_dir, task, 'public', 'training_prevalences.txt')
        train_gen = LabelledCollectionsFromDir(training_samples_path, training_true_prev_path, load_fn=load_fn)
        if merge_T3:
            train = LabelledCollection.join(*list(train_gen()))
            return train, val_gen, test_gen
        else:
            return train_gen, val_gen, test_gen