# TODO: this should be an implementation (subclass) of an abstract MultilingualDataset

import pickle
import re
from abc import ABC, abstractmethod
from os.path import expanduser, join

import numpy as np
from scipy.sparse import csr_matrix, issparse
from tqdm import tqdm


class NewMultilingualDataset(ABC):
    """Abstract interface for a multilingual dataset.

    Every method is abstract, so a subclass can only be instantiated once it
    implements all six of them; otherwise Python raises ``TypeError``.
    NOTE(review): argument lists and return types are not specified yet —
    subclasses currently define the contract.
    """

    @abstractmethod
    def get_training(self):
        """Abstract hook: accessor for the training split."""

    @abstractmethod
    def get_validation(self):
        """Abstract hook: accessor for the validation split."""

    @abstractmethod
    def get_test(self):
        """Abstract hook: accessor for the test split."""

    @abstractmethod
    def mask_numbers(self):
        """Abstract hook: number-masking step."""

    @abstractmethod
    def save(self):
        """Abstract hook: persist the dataset."""

    @abstractmethod
    def load(self):
        """Abstract hook: restore the dataset."""


# class RcvMultilingualDataset(MultilingualDataset):
class RcvMultilingualDataset:
    """Handle on one pickled RCV1/RCV2 multilingual dataset run.

    Construction only computes the pickle's location; the file is read by
    :meth:`load`, which keeps the unpickled object in ``self.data``.
    """

    def __init__(
        self,
        run="0",
    ):
        """Compute the pickle path for ``run``.

        Args:
            run: run identifier interpolated into the file name
                (``..._processed_run{run}.pickle``).
        """
        self.dataset_name = "rcv1-2"
        self.dataset_path = expanduser(
            f"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run{run}.pickle"
        )
        # Filled in by load(); None until the pickle has been read.
        self.data = None

    def load(self):
        """Unpickle ``self.dataset_path`` into ``self.data`` and return ``self``.

        The original implementation discarded the unpickled object; it is now
        kept on the instance. NOTE(review): presumably the pickle holds a
        ``MultilingualDataset`` (the module's ``__main__`` loads this same path
        that way) — confirm. pickle can execute arbitrary code, so only load
        trusted files.
        """
        with open(self.dataset_path, "rb") as f:
            self.data = pickle.load(f)
        return self


class MultilingualDataset:
    """
    A multilingual dataset is a dictionary of training and test documents indexed by language code.
    Train and test sets are represented as tuples of the type (X,Y,ids), where X is a matrix representation of the
    documents (e.g., a document-by-term sparse csr_matrix), Y is a document-by-label binary np.array indicating the
    labels of each document, and ids is a list of document-identifiers from the original collection (or None).

    Optional views (see ``set_view``) restrict which languages and which label
    columns the accessor methods expose, without modifying the stored data.
    """

    def __init__(self, dataset_name):
        """Create an empty dataset named ``dataset_name`` (announced on stdout)."""
        self.dataset_name = dataset_name
        # lang -> ((Xtr, Ytr, tr_ids), (Xte, Yte, te_ids))
        self.multiling_dataset = {}
        print(f"[Init Multilingual Dataset: {self.dataset_name}]")

    def add(self, lang, Xtr, Ytr, Xte, Yte, tr_ids=None, te_ids=None):
        """Register (or overwrite) the train and test splits of language ``lang``."""
        self.multiling_dataset[lang] = ((Xtr, Ytr, tr_ids), (Xte, Yte, te_ids))

    def save(self, file):
        """Pickle the whole object to ``file`` and return ``self``.

        Sparse ``X`` matrices get their indices sorted before pickling.
        """
        self.sort_indexes()
        with open(file, "wb") as f:
            pickle.dump(self, f, pickle.HIGHEST_PROTOCOL)
        return self

    def __getitem__(self, item):
        """Return ``((Xtr, Ytr, tr_ids), (Xte, Yte, te_ids))`` for language
        ``item``, or ``None`` if ``item`` is not a currently visible language."""
        if item in self.langs():
            return self.multiling_dataset[item]
        return None

    @classmethod
    def load(cls, file):
        """Unpickle a dataset written by ``save`` and sort its sparse indices.

        NOTE: pickle can execute arbitrary code; only load trusted files.
        """
        with open(file, "rb") as f:
            data = pickle.load(f)
        data.sort_indexes()
        return data

    @classmethod
    def load_ids(cls, file):
        """Return ``(tr_ids, te_ids)``: dicts mapping every stored language
        (views are ignored) to its train / test document ids."""
        with open(file, "rb") as f:
            data = pickle.load(f)
        tr_ids = {
            lang: tr_ids
            for (lang, ((_, _, tr_ids), (_, _, _))) in data.multiling_dataset.items()
        }
        te_ids = {
            lang: te_ids
            for (lang, ((_, _, _), (_, _, te_ids))) in data.multiling_dataset.items()
        }
        return tr_ids, te_ids

    def sort_indexes(self):
        """Sort, in place, the indices of every sparse ``Xtr``/``Xte``."""
        for (Xtr, _, _), (Xte, _, _) in self.multiling_dataset.values():
            for X in (Xtr, Xte):
                if issparse(X):
                    X.sort_indices()

    def set_view(self, categories=None, languages=None):
        """Restrict the label columns and/or languages exposed by the accessors.

        Args:
            categories: an int (builtin or NumPy), a list of ints, or an index
                array of label columns to keep. ``None`` keeps the current
                category view.
            languages: sequence of language codes to expose. ``None`` keeps the
                current language view.
        """
        if categories is not None:
            # Wrap scalars (including NumPy integers) so Y[:, view] stays 2-D.
            if isinstance(categories, (int, np.integer)):
                categories = np.array([categories])
            elif isinstance(categories, list):
                categories = np.array(categories)
            self.categories_view = categories
        if languages is not None:
            self.languages_view = languages

    def training(self, mask_numbers=False, target_as_csr=False):
        """Return ``(lXtr, lYtr)`` dicts for the visible languages."""
        return self.lXtr(mask_numbers), self.lYtr(as_csr=target_as_csr)

    def test(self, mask_numbers=False, target_as_csr=False):
        """Return ``(lXte, lYte)`` dicts for the visible languages."""
        return self.lXte(mask_numbers), self.lYte(as_csr=target_as_csr)

    def lXtr(self, mask_numbers=False):
        """Return ``{lang: Xtr}`` for visible languages.

        With ``mask_numbers=True`` each Xtr is passed through ``_mask_numbers``,
        which expects an iterable of raw text strings.
        """
        proc = lambda x: _mask_numbers(x) if mask_numbers else x
        visible = self.langs()
        return {
            lang: proc(Xtr)
            for (lang, ((Xtr, _, _), _)) in self.multiling_dataset.items()
            if lang in visible
        }

    def lXte(self, mask_numbers=False):
        """Return ``{lang: Xte}`` for visible languages (see ``lXtr``)."""
        proc = lambda x: _mask_numbers(x) if mask_numbers else x
        visible = self.langs()
        return {
            lang: proc(Xte)
            for (lang, (_, (Xte, _, _))) in self.multiling_dataset.items()
            if lang in visible
        }

    def lYtr(self, as_csr=False):
        """Return ``{lang: Ytr}`` for visible languages, restricted to the
        category view; converted to ``csr_matrix`` when ``as_csr`` is true."""
        visible = self.langs()
        lY = {
            lang: self.cat_view(Ytr)
            for (lang, ((_, Ytr, _), _)) in self.multiling_dataset.items()
            if lang in visible
        }
        if as_csr:
            lY = {lang: csr_matrix(Y) for lang, Y in lY.items()}
        return lY

    def lYte(self, as_csr=False):
        """Return ``{lang: Yte}`` for visible languages (see ``lYtr``)."""
        visible = self.langs()
        lY = {
            lang: self.cat_view(Yte)
            for (lang, (_, (_, Yte, _))) in self.multiling_dataset.items()
            if lang in visible
        }
        if as_csr:
            lY = {lang: csr_matrix(Y) for lang, Y in lY.items()}
        return lY

    def cat_view(self, Y):
        """Return ``Y`` restricted to the category view's columns, if one is set."""
        if hasattr(self, "categories_view"):
            return Y[:, self.categories_view]
        else:
            return Y

    def langs(self):
        """Return the language view if set, else all stored languages, sorted."""
        if hasattr(self, "languages_view"):
            langs = self.languages_view
        else:
            langs = sorted(self.multiling_dataset.keys())
        return langs

    def num_categories(self):
        """Number of visible label columns, measured on the first visible
        language's training labels."""
        return self.lYtr()[self.langs()[0]].shape[1]

    def show_dimensions(self):
        """Print the X/Y shapes (or lengths) of each visible language's splits."""

        def shape(X):
            return X.shape if hasattr(X, "shape") else len(X)

        visible = self.langs()
        for lang, (
            (Xtr, Ytr, IDtr),
            (Xte, Yte, IDte),
        ) in self.multiling_dataset.items():
            if lang not in visible:
                continue
            print(
                "Lang {}, Xtr={}, ytr={}, Xte={}, yte={}".format(
                    lang,
                    shape(Xtr),
                    self.cat_view(Ytr).shape,
                    shape(Xte),
                    self.cat_view(Yte).shape,
                )
            )

    def show_category_prevalences(self):
        """Print and return per-category positive counts.

        Returns:
            ``(accum_tr, accum_te, in_langs)``: per-category sums of positives
            over visible languages in train and test, and the number of visible
            languages with at least one positive training example per category.
        """
        nC = self.num_categories()
        # Builtin int: np.int was only an alias for it and was removed in NumPy 1.24.
        accum_tr = np.zeros(nC, dtype=int)
        accum_te = np.zeros(nC, dtype=int)
        in_langs = np.zeros(
            nC, dtype=int
        )  # count languages with at least one positive example (per category)
        visible = self.langs()
        for lang, (
            (Xtr, Ytr, IDtr),
            (Xte, Yte, IDte),
        ) in self.multiling_dataset.items():
            if lang not in visible:
                continue
            prev_train = np.sum(self.cat_view(Ytr), axis=0)
            prev_test = np.sum(self.cat_view(Yte), axis=0)
            accum_tr += prev_train
            accum_te += prev_test
            in_langs += (prev_train > 0) * 1
            print(lang + "-train", prev_train)
            print(lang + "-test", prev_test)
        print("all-train", accum_tr)
        print("all-test", accum_te)

        return accum_tr, accum_te, in_langs

    def set_labels(self, labels):
        """Store ``labels`` on the instance as ``self.labels``."""
        self.labels = labels

    def reduce_data(self, langs=None, maxn=50):
        """Keep only ``langs`` and truncate every split to its first ``maxn`` docs.

        Args:
            langs: language codes to keep; defaults to ``["it", "en"]``. A
                ``None`` sentinel avoids sharing one mutable default list
                (which also becomes ``languages_view``) across calls.
            maxn: maximum number of documents kept per split.

        Returns:
            ``self``, modified in place; the language view is set to ``langs``.
        """
        if langs is None:
            langs = ["it", "en"]
        print(f"- Reducing data: {langs} with max {maxn} documents...\n")
        self.set_view(languages=langs)

        self.multiling_dataset = {
            lang: self._reduce(splits, maxn)
            for lang, splits in self.multiling_dataset.items()
            if lang in langs
        }
        return self

    def _reduce(self, multilingual_dataset, maxn):
        """Truncate each ``(docs, labels, ids)`` split of one language to ``maxn`` rows.

        ``ids`` may be ``None`` (``add`` allows it) and is then kept as ``None``.
        Returns a tuple of splits, matching the structure built by ``add``.
        """
        return tuple(
            (docs[:maxn], labels[:maxn], None if ids is None else ids[:maxn])
            for docs, labels, ids in multilingual_dataset
        )


def _mask_numbers(data):
    """Replace numeric tokens in each text with digit-count placeholder words.

    For every text: a leading space is prepended so a number at the very start
    is matched by the whitespace-anchored patterns. Numbers with 5+, 4, 3, 2 and 1
    leading digits are then replaced, in that order (longest first), by
    ``MoreDigitMask`` / ``FourDigitMask`` / ``ThreeDigitMask`` / ``TwoDigitMask`` /
    ``OneDigitMask``. Finally every ``.`` and ``,`` is removed and the result is
    stripped.

    Args:
        data: iterable of strings.

    Returns:
        list of masked strings, in input order.
    """
    # Order matters: longer digit runs must be masked before shorter ones.
    substitutions = [
        (re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b"), " MoreDigitMask"),
        (re.compile(r"\s[\+-]?\d{4}([\.,]\d*)*\b"), " FourDigitMask"),
        (re.compile(r"\s[\+-]?\d{3}([\.,]\d*)*\b"), " ThreeDigitMask"),
        (re.compile(r"\s[\+-]?\d{2}([\.,]\d*)*\b"), " TwoDigitMask"),
        (re.compile(r"\s[\+-]?\d{1}([\.,]\d*)*\b"), " OneDigitMask"),
    ]

    def _mask_one(doc):
        doc = " " + doc
        for pattern, token in substitutions:
            doc = pattern.sub(token, doc)
        return doc.replace(".", "").replace(",", "").strip()

    return [_mask_one(doc) for doc in tqdm(data, desc="masking numbers")]


if __name__ == "__main__":
    DATAPATH = expanduser(
        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
    )
    print(DATAPATH)
    # load() is a classmethod returning the unpickled dataset. The previous
    # MultilingualDataset().load(...) raised TypeError, because __init__
    # requires a dataset_name.
    dataset = MultilingualDataset.load(DATAPATH)
    # show_dimensions() prints its own report and returns None, so it is not
    # wrapped in print() (which would emit a stray "None").
    dataset.show_dimensions()