diff --git a/main.py b/main.py index 0acf8f1..da2748d 100644 --- a/main.py +++ b/main.py @@ -54,10 +54,11 @@ def main(args): # Init DocEmbedderList (i.e., first-tier learners or view generators) and metaclassifier docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True) meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf'), - meta_parameters=get_params(optimc=args.optimc)) + meta_parameters=get_params(optimc=args.optimc), + n_jobs=args.n_jobs) # Init Funnelling Architecture - gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta) + gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta, n_jobs=args.n_jobs) # Training --------------------------------------- print('\n[Training Generalized Funnelling]') diff --git a/src/view_generators.py b/src/view_generators.py index af4ee8e..e972ce7 100644 --- a/src/view_generators.py +++ b/src/view_generators.py @@ -361,7 +361,7 @@ class BertGen(ViewGen): :param ly: dict {lang: target vectors} :return: self. """ - print('# Fitting BertGen (M)...') + print('# Fitting BertGen (B)...') create_if_not_exist(self.logger.save_dir) self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1) bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)