From d2736912239c224f2392fbf3dc8dbed1eed52f80 Mon Sep 17 00:00:00 2001 From: andrea Date: Thu, 11 Feb 2021 12:44:32 +0100 Subject: [PATCH] running comparison --- src/main_gFun.py | 6 +++--- src/models/mBert.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/main_gFun.py b/src/main_gFun.py index a6f982b..2f11cb7 100644 --- a/src/main_gFun.py +++ b/src/main_gFun.py @@ -8,7 +8,7 @@ from util.common import * from util.parser_options import * import os -os.environ["CUDA_VISIBLE_DEVICES"] = "1" +os.environ["CUDA_VISIBLE_DEVICES"] = "0" if __name__ == '__main__': @@ -108,8 +108,8 @@ if __name__ == '__main__': """ View generator (-B): generates document embedding via mBERT model. """ - op.bert_path = '/home/andreapdr/funneling_pdr/hug_checkpoint/mBERT-rcv1-2_run0' # TODO DEBUG - op.bert_path = None + op.bert_path = '/home/andreapdr/gfun/hug_checkpoint/pytorch_model.bin' + # op.bert_path = None mbert = MBertEmbedder(path_to_model=op.bert_path, nC=data.num_categories(), options=op) if op.allprob: diff --git a/src/models/mBert.py b/src/models/mBert.py index b8ef5ee..a1fb70e 100644 --- a/src/models/mBert.py +++ b/src/models/mBert.py @@ -100,7 +100,8 @@ class ExtractorDataset(Dataset): def get_model(n_out): print('# Initializing model ...') - model = BertForSequenceClassification.from_pretrained('bert-base-multilingual-cased', num_labels=n_out) + model = BertForSequenceClassification.from_pretrained('bert-base-multilingual-cased', num_labels=n_out, + output_hidden_states=True) return model