diff --git a/src/view_generators.py b/src/view_generators.py index a690e8f..05c5263 100644 --- a/src/view_generators.py +++ b/src/view_generators.py @@ -363,6 +363,8 @@ class RecurrentGen(ViewGen): :param lX: dict {lang: indexed documents} :return: documents projected to the common latent space. """ + if self.zero_shot: + lX = self.zero_shot_experiments(lX) data = {} for lang in lX.keys(): indexed = index(data=lX[lang], @@ -381,6 +383,12 @@ class RecurrentGen(ViewGen): def fit_transform(self, lX, ly): return self.fit(lX, ly).transform(lX) + def zero_shot_experiments(self, lX): + for lang in sorted(lX.keys()): + if lang not in self.train_langs: + lX.pop(lang) + return lX + def set_zero_shot(self, val: bool): self.zero_shot = val return @@ -458,6 +466,8 @@ class BertGen(ViewGen): :param lX: dict {lang: indexed documents} :return: documents projected to the common latent space. """ + if self.zero_shot: + lX = self.zero_shot_experiments(lX) data = tokenize(lX, max_len=512) self.model.to('cuda' if self.gpus else 'cpu') self.model.eval() @@ -468,6 +478,12 @@ class BertGen(ViewGen): # we can assume that we have already indexed data for transform() since we are first calling fit() return self.fit(lX, ly).transform(lX) + def zero_shot_experiments(self, lX): + for lang in sorted(lX.keys()): + if lang not in self.train_langs: + lX.pop(lang) + return lX + def set_zero_shot(self, val: bool): self.zero_shot = val return