# gfun_multimodal/gfun/vgfs/multilingualGen.py

from os.path import expanduser, join
from types import SimpleNamespace

import torch
import numpy as np
from torchtext.vocab import Vectors
from joblib import Parallel, delayed

from vgfs.viewGen import ViewGen
from vgfs.commons import _normalize, XdotM
from vgfs.learners.svms import FeatureSet2Posteriors


class MultilingualGen(ViewGen):
    def __init__(
        self,
        cached=False,
        langs=["en", "it"],
        embed_dir="~/embeddings",
        n_jobs=-1,
        probabilistic=False,
    ):
        print("- init Multilingual View Generating Function")
        self.embed_dir = embed_dir
        self.langs = langs
        self.n_jobs = n_jobs
        self.cached = cached
        self.vectorizer = None  # only initialised here; must be assigned externally before fit() is called
        self.sif = True
        self.probabilistic = probabilistic
        self.fitted = False
        self._init()

    def _init(self):
        if self.probabilistic:
            self.feature2posterior_projector = FeatureSet2Posteriors(
                n_jobs=self.n_jobs, verbose=False
            )

    def fit(self, lX, lY):
        """
        Fitting the Multilingual View Generating Function consists of
        building/extracting the word-embedding matrix for each language.

        :param lX: dict {lang: documents}
        :param lY: dict {lang: labels}; used only to fit the posterior
            projector when probabilistic=True
        """
        print("- fitting Multilingual View Generating Function")
        self.l_vocab = self.vectorizer.vocabulary()
        self.multi_embeddings, self.langs = self._load_embeddings(
            self.embed_dir, self.cached
        )
        if self.probabilistic:
            self.feature2posterior_projector.fit(self.transform(lX), lY)
        self.fitted = True
        return self

    def transform(self, lX):
        """
        Project each language's vectorized documents onto its embedding matrix
        (XdotM, weighted according to the ``sif`` option), L2-normalize the
        result and, if probabilistic, map it to posterior probabilities.
        """
        lX = self.vectorizer.transform(lX)
        XdotMulti = Parallel(n_jobs=self.n_jobs)(
            delayed(XdotM)(lX[lang], self.multi_embeddings[lang], sif=self.sif)
            for lang in self.langs
        )
        lZ = {lang: XdotMulti[i] for i, lang in enumerate(self.langs)}
        lZ = _normalize(lZ, l2=True)
        if self.probabilistic and self.fitted:
            lZ = self.feature2posterior_projector.transform(lZ)
        return lZ

    def fit_transform(self, lX, lY):
        return self.fit(lX, lY).transform(lX)

    def _load_embeddings(self, embed_dir, cached):
        if "muse" in self.embed_dir.lower():
            multi_embeddings = load_MUSEs(
                langs=self.langs,
                l_vocab=self.vectorizer.vocabulary(),
                dir_path=embed_dir,
                cached=cached,
            )
            return multi_embeddings, sorted(multi_embeddings.keys())
        raise NotImplementedError(
            f"MultilingualGen only supports MUSE embeddings, got embed_dir='{embed_dir}'"
        )

    def get_config(self):
        return {
            "name": "Multilingual VGF",
            "embed_dir": self.embed_dir,
            "langs": self.langs,
            "n_jobs": self.n_jobs,
            "cached": self.cached,
            "sif": self.sif,
            "probabilistic": self.probabilistic,
        }

    def save_vgf(self, model_id):
        import pickle
        from os import makedirs

        vgf_name = "multilingualGen"
        _basedir = join("models", "vgfs", "multilingual")
        makedirs(_basedir, exist_ok=True)
        _path = join(_basedir, f"{vgf_name}_{model_id}.pkl")
        with open(_path, "wb") as f:
            pickle.dump(self, f)
        return self
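
    # A VGF saved by save_vgf can be restored with the standard pickle API,
    # e.g. ``pickle.load(open(path, "rb"))``.
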
    def __str__(self):
        _str = f"[Multilingual VGF (m)]\n- embed_dir: {self.embed_dir}\n- langs: {self.langs}\n- n_jobs: {self.n_jobs}\n- cached: {self.cached}\n- sif: {self.sif}\n- probabilistic: {self.probabilistic}\n"
        return _str

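
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). The class
# expects a multilingual vectorizer exposing ``vocabulary()`` and
# ``transform()`` keyed by language code, assigned to ``.vectorizer`` before
# ``fit`` is called; ``my_vectorizer``, ``lX`` and ``lY`` below are
# hypothetical placeholders:
#
#     vgf = MultilingualGen(langs=["en", "it"], embed_dir="~/embeddings/muse", cached=True)
#     vgf.vectorizer = my_vectorizer
#     lZ = vgf.fit_transform(lX, lY)   # -> {lang: document embeddings}
# ---------------------------------------------------------------------------
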
def load_MUSEs(langs, l_vocab, dir_path, cached=False):
    dir_path = expanduser(dir_path)
    cached_dir = join(dir_path, "cached")
    nmax = 50000
    l_embeddings = {}
    for lang in langs:
        embed_path = f"wiki.multi.{lang}.vec"
        if cached:
            l_embeddings[lang] = Vectors(embed_path, cache=cached_dir)
            print(f"-- Loaded cached {lang} embeddings")
        else:
            # extract() expects objects exposing .stoi and .vectors (as the cached
            # torchtext Vectors do), so the raw matrix and word index returned by
            # _load_vec are wrapped accordingly.
            _embed_matrix, _, _word2id = _load_vec(join(dir_path, embed_path), nmax)
            l_embeddings[lang] = SimpleNamespace(
                stoi=_word2id, vectors=torch.from_numpy(_embed_matrix).float()
            )
            print(f"-- Loaded {nmax} {lang} embeddings")
    # print("-- Extracting embeddings")
    l_embeddings = extract(l_vocab, l_embeddings)
    return l_embeddings
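
# The MUSE loader above infers the following on-disk layout from the paths it
# builds: one ``wiki.multi.<lang>.vec`` file per language directly under
# ``dir_path``, plus (when ``cached=True``) torchtext caches under
# ``dir_path/cached``.
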
def _load_vec(emb_path, nmax=50000):
    """
    Load up to nmax word vectors from a fastText/MUSE-style .vec text file
    (a header line followed by lines of the form "word v1 v2 ... vd").

    :return: (embedding matrix, id2word, word2id)
    """
    import io

    vectors = []
    word2id = {}
    with io.open(emb_path, "r", encoding="utf-8", newline="\n", errors="ignore") as f:
        next(f)  # skip the header line (word count and dimensionality)
        for i, line in enumerate(f):
            word, vect = line.rstrip().split(" ", 1)
            vect = np.fromstring(vect, sep=" ")
            assert word not in word2id, "word found twice"
            vectors.append(vect)
            word2id[word] = len(word2id)
            if len(word2id) == nmax:
                break
    id2word = {v: k for k, v in word2id.items()}
    embeddings = np.vstack(vectors)
    return embeddings, id2word, word2id

def extract(l_voc, l_embeddings):
    """
    Re-index the pretrained embeddings so that their row indices match the
    indices assigned by the scikit-learn vectorizer. These indices are
    consistent with those used by Word Class Embeddings, since the same
    vectorizer is deployed. Words without a pretrained vector keep a zero row.

    :param l_voc: dict {lang: {word: id}}
    :param l_embeddings: dict {lang: pretrained embeddings exposing .stoi and .vectors}
    :return: dict {lang: torch matrix of extracted embeddings, i.e., the words in l_voc}
    """
    l_extracted = {}
    for lang, words in l_voc.items():
        source_id, target_id = reindex(words, l_embeddings[lang].stoi)
        extraction = torch.zeros((len(words), l_embeddings[lang].vectors.shape[-1]))
        extraction[source_id] = l_embeddings[lang].vectors[target_id]
        l_extracted[lang] = extraction
    return l_extracted

def reindex(vectorizer_words, pretrained_word2index):
    """
    Align the vectorizer vocabulary with a pretrained word index.

    :return: (source_idx, target_idx) where source_idx are positions in the
        vectorizer vocabulary and target_idx the corresponding rows of the
        pretrained matrix; out-of-vocabulary words are skipped.
    """
    if isinstance(vectorizer_words, dict):
        # sort the {word: id} mapping by id so positions match the vectorizer columns
        vectorizer_words = list(
            zip(*sorted(vectorizer_words.items(), key=lambda x: x[1]))
        )[0]
    source_idx, target_idx = [], []
    for i, word in enumerate(vectorizer_words):
        if word not in pretrained_word2index:
            continue
        j = pretrained_word2index[word]
        source_idx.append(i)
        target_idx.append(j)
    source_idx = np.asarray(source_idx)
    target_idx = np.asarray(target_idx)
    return source_idx, target_idx
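

if __name__ == "__main__":
    # Minimal toy check (illustrative only): align a 3-word vectorizer vocabulary
    # with a made-up 2-word pretrained embedding space. The vocabularies and
    # vectors below are fabricated for the example.
    from types import SimpleNamespace

    vectorizer_vocab = {"dog": 0, "cat": 1, "unicorn": 2}
    pretrained = SimpleNamespace(  # mimics the .stoi / .vectors interface used above
        stoi={"cat": 0, "dog": 1},
        vectors=torch.tensor([[1.0, 1.0], [2.0, 2.0]]),
    )

    src, tgt = reindex(vectorizer_vocab, pretrained.stoi)
    print(src, tgt)  # -> [0 1] [1 0]; "unicorn" has no pretrained vector and is skipped

    l_extracted = extract({"en": vectorizer_vocab}, {"en": pretrained})
    print(l_extracted["en"])
    # row 0 ("dog") -> [2., 2.], row 1 ("cat") -> [1., 1.], row 2 ("unicorn") -> zeros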