from os.path import expanduser, join

import torch
import numpy as np
from torchtext.vocab import Vectors
from joblib import Parallel, delayed
from gfun.vgfs.viewGen import ViewGen
from gfun.vgfs.commons import _normalize, XdotM
from gfun.vgfs.learners.svms import FeatureSet2Posteriors


class MultilingualGen(ViewGen):
    """View Generating Function based on pretrained multilingual word embeddings.

    Documents are vectorized with ``self.vectorizer``, multiplied (per language)
    by the corresponding embedding matrix via ``XdotM`` (with ``sif`` weighting
    flag), and L2-normalized. When ``probabilistic`` is True, a
    ``FeatureSet2Posteriors`` projector is fitted on those features in ``fit``
    and applied on top of them in ``transform``.

    NOTE(review): ``self.vectorizer`` is None after construction and is
    presumably assigned by the caller before ``fit``/``transform`` — confirm.
    """

    def __init__(
        self,
        cached=False,
        langs=None,
        embed_dir="~/embeddings",
        n_jobs=-1,
        probabilistic=False,
    ):
        """
        :param cached: forwarded to ``load_MUSEs``; selects the torchtext-cached loading path.
        :param langs: languages to handle; defaults to ``["en", "it"]``. Replaced
            in ``fit`` by the sorted list of languages actually loaded.
        :param embed_dir: directory holding the embedding files; only paths
            containing "muse" (case-insensitive) are currently supported.
        :param n_jobs: joblib parallelism used for the per-language projection.
        :param probabilistic: if True, ``transform`` returns the output of a
            fitted ``FeatureSet2Posteriors`` projector instead of raw features.
        """
        print("- init Multilingual View Generating Function")
        self.embed_dir = embed_dir
        # Build a fresh list per instance instead of sharing a mutable default.
        self.langs = ["en", "it"] if langs is None else langs
        self.n_jobs = n_jobs
        self.cached = cached
        self.vectorizer = None
        self.sif = True
        self.probabilistic = probabilistic
        self.fitted = False
        self._init()

    def _init(self):
        """Create the posterior projector when running in probabilistic mode."""
        if self.probabilistic:
            self.feature2posterior_projector = FeatureSet2Posteriors(
                n_jobs=self.n_jobs, verbose=False
            )

    def fit(self, lX, lY):
        """
        Fitting Multilingual View Generating Function consists in
        building/extracting the word embedding matrix for each language
        (and, in probabilistic mode, fitting the posterior projector).

        :param lX: dict {lang: documents} accepted by ``self.vectorizer.transform``.
        :param lY: labels forwarded to the projector's ``fit``.
        :return: self
        """
        print("- fitting Multilingual View Generating Function")
        self.l_vocab = self.vectorizer.vocabulary()
        self.multi_embeddings, self.langs = self._load_embeddings(
            self.embed_dir, self.cached
        )

        if self.probabilistic:
            # Fit on the raw embedding features. Going through transform() here
            # would, on a re-fit (self.fitted already True), feed the projector
            # its own previous outputs instead of the features.
            self.feature2posterior_projector.fit(self._embed(lX), lY)

        self.fitted = True

        return self

    def _embed(self, lX):
        """Vectorize ``lX``, project each language onto its embeddings, L2-normalize."""
        lX = self.vectorizer.transform(lX)

        projections = Parallel(n_jobs=self.n_jobs)(
            delayed(XdotM)(lX[lang], self.multi_embeddings[lang], sif=self.sif)
            for lang in self.langs
        )
        lZ = dict(zip(self.langs, projections))
        return _normalize(lZ, l2=True)

    def transform(self, lX):
        """
        Map documents to the multilingual embedding view.

        :param lX: dict {lang: documents}.
        :return: dict {lang: features}; projected through the posterior
            projector when ``probabilistic`` is set and the VGF is fitted.
        """
        lZ = self._embed(lX)
        if self.probabilistic and self.fitted:
            lZ = self.feature2posterior_projector.transform(lZ)
        return lZ

    def fit_transform(self, lX, lY):
        """Equivalent to ``fit(lX, lY)`` followed by ``transform(lX)``."""
        return self.fit(lX, lY).transform(lX)

    def _load_embeddings(self, embed_dir, cached):
        """
        Load and reindex the embeddings found in ``embed_dir``.

        :return: (dict {lang: embedding matrix}, sorted list of loaded languages)
        :raises ValueError: if ``embed_dir`` does not refer to MUSE embeddings,
            the only supported kind (previously this silently returned None and
            ``fit`` crashed while unpacking it).
        """
        if "muse" not in embed_dir.lower():
            raise ValueError(
                f"Unsupported embeddings directory {embed_dir!r}: only MUSE "
                "embeddings (a path containing 'muse') are supported"
            )
        multi_embeddings = load_MUSEs(
            langs=self.langs,
            l_vocab=self.vectorizer.vocabulary(),
            dir_path=embed_dir,
            cached=cached,
        )
        return multi_embeddings, sorted(multi_embeddings.keys())

    def get_config(self):
        """Return a dict describing this VGF's configuration."""
        return {
            "name": "Multilingual VGF",
            "embed_dir": self.embed_dir,
            "langs": self.langs,
            "n_jobs": self.n_jobs,
            "cached": self.cached,
            "sif": self.sif,
            "probabilistic": self.probabilistic,
        }

    def save_vgf(self, model_id):
        """
        Pickle this VGF to ``models/vgfs/multilingual/multilingualGen_<model_id>.pkl``
        (relative to the current working directory), creating the directory if needed.

        :return: self
        """
        import pickle
        from os import makedirs

        vgf_name = "multilingualGen"
        _basedir = join("models", "vgfs", "multilingual")
        makedirs(_basedir, exist_ok=True)
        _path = join(_basedir, f"{vgf_name}_{model_id}.pkl")
        with open(_path, "wb") as f:
            pickle.dump(self, f)
        return self

    def __str__(self):
        return (
            "[Multilingual VGF (m)]\n"
            f"- embed_dir: {self.embed_dir}\n"
            f"- langs: {self.langs}\n"
            f"- n_jobs: {self.n_jobs}\n"
            f"- cached: {self.cached}\n"
            f"- sif: {self.sif}\n"
            f"- probabilistic: {self.probabilistic}\n"
        )


def load_MUSEs(langs, l_vocab, dir_path, cached=False, nmax=50000):
    """
    Load MUSE aligned embeddings (``wiki.multi.<lang>.vec``) for each language
    and reindex them onto the vectorizer vocabulary via ``extract``.

    :param langs: iterable of language codes.
    :param l_vocab: dict {lang: vocabulary} forwarded to ``extract``.
    :param dir_path: directory holding the ``.vec`` files (``~`` is expanded).
    :param cached: if True, load through torchtext ``Vectors`` using
        ``<dir_path>/cached`` as cache directory; otherwise parse the ``.vec``
        files directly with ``_load_vec``.
    :param nmax: maximum number of vectors read per language when ``cached``
        is False.
    :return: dict {lang: torch.Tensor} as produced by ``extract``.
    """
    from types import SimpleNamespace

    dir_path = expanduser(dir_path)
    cached_dir = join(dir_path, "cached")

    l_embeddings = {}

    for lang in langs:
        embed_path = f"wiki.multi.{lang}.vec"
        if cached:
            l_embeddings[lang] = Vectors(embed_path, cache=cached_dir)
            print(f"-- Loaded cached {lang} embeddings")
        else:
            embed_matrix, _, word2id = _load_vec(join(dir_path, embed_path), nmax)
            # extract() reads `.stoi` and `.vectors` (the torchtext Vectors
            # interface); a bare ndarray has neither, so wrap it together with
            # its word->row mapping.
            l_embeddings[lang] = SimpleNamespace(
                stoi=word2id, vectors=torch.from_numpy(embed_matrix).float()
            )
            # Report the number actually read (the file may hold fewer than nmax).
            print(f"-- Loaded {len(word2id)} {lang} embeddings")

    return extract(l_vocab, l_embeddings)


def _load_vec(emb_path, nmax=50000):
    import io

    import numpy as np

    vectors = []
    word2id = {}
    with io.open(emb_path, "r", encoding="utf-8", newline="\n", errors="ignore") as f:
        next(f)
        for i, line in enumerate(f):
            word, vect = line.rstrip().split(" ", 1)
            vect = np.fromstring(vect, sep=" ")
            assert word not in word2id, "word found twice"
            vectors.append(vect)
            word2id[word] = len(word2id)
            if len(word2id) == nmax:
                break
    id2word = {v: k for k, v in word2id.items()}
    embeddings = np.vstack(vectors)
    return embeddings, id2word, word2id


def extract(l_voc, l_embeddings):
    """
    Reindex pretrained loaded embedding in order to match indexes
    assigned by scikit vectorizer. Such indexes are consistent with
    those used by Word Class Embeddings (since we deploy the same vectorizer)

    :param l_voc: dict {lang: vocabulary}, each vocabulary being either a
        {word: id} dict or a sequence of words ordered by id.
    :param l_embeddings: dict {lang: obj} where ``obj.stoi`` maps word -> row
        and ``obj.vectors`` is a 2-D tensor of pretrained vectors.
    :return: dict {lang: torch.Tensor of shape (len(vocabulary), dim)}; rows of
        words without a pretrained vector are left as zeros.
    """
    l_extracted = {}
    for lang, words in l_voc.items():
        pretrained = l_embeddings[lang]
        source_id, target_id = reindex(words, pretrained.stoi)
        # Force integer index tensors: an empty numpy index array defaults to
        # float64, which torch refuses as an index.
        source_id = torch.as_tensor(source_id, dtype=torch.long)
        target_id = torch.as_tensor(target_id, dtype=torch.long)
        extraction = torch.zeros((len(words), pretrained.vectors.shape[-1]))
        extraction[source_id] = pretrained.vectors[target_id]
        l_extracted[lang] = extraction
    return l_extracted


def reindex(vectorizer_words, pretrained_word2index):
    """
    Pair vectorizer vocabulary positions with pretrained-embedding rows.

    :param vectorizer_words: either a {word: id} dict (words are then taken in
        ascending id order) or a sequence of words.
    :param pretrained_word2index: {word: row} mapping of the pretrained embeddings.
    :return: (source_idx, target_idx) int64 arrays of equal length: for every
        vectorizer word present in ``pretrained_word2index``, its position in
        the ordered vocabulary and its pretrained row. Words missing from the
        pretrained vocabulary are skipped.
    """
    if isinstance(vectorizer_words, dict):
        # A list comprehension (unlike zip(*...)[0]) also handles an empty dict.
        vectorizer_words = [
            word for word, _ in sorted(vectorizer_words.items(), key=lambda kv: kv[1])
        ]

    pairs = [
        (i, pretrained_word2index[word])
        for i, word in enumerate(vectorizer_words)
        if word in pretrained_word2index
    ]
    # Explicit integer dtype so the arrays are usable as indices even when empty.
    source_idx = np.asarray([src for src, _ in pairs], dtype=np.int64)
    target_idx = np.asarray([tgt for _, tgt in pairs], dtype=np.int64)
    return source_idx, target_idx