Compare commits
No commits in common. "master" and "devel" have entirely different histories.
8
main.py
8
main.py
|
|
@ -49,16 +49,16 @@ def main(args):
|
||||||
if args.bert_embedder:
|
if args.bert_embedder:
|
||||||
bertEmbedder = BertGen(multilingualIndex, batch_size=args.batch_bert, nepochs=args.nepochs_bert,
|
bertEmbedder = BertGen(multilingualIndex, batch_size=args.batch_bert, nepochs=args.nepochs_bert,
|
||||||
patience=args.patience_bert, gpus=args.gpus, n_jobs=args.n_jobs)
|
patience=args.patience_bert, gpus=args.gpus, n_jobs=args.n_jobs)
|
||||||
|
bertEmbedder.transform(lX)
|
||||||
embedder_list.append(bertEmbedder)
|
embedder_list.append(bertEmbedder)
|
||||||
|
|
||||||
# Init DocEmbedderList (i.e., first-tier learners or view generators) and metaclassifier
|
# Init DocEmbedderList (i.e., first-tier learners or view generators) and metaclassifier
|
||||||
docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True)
|
docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True)
|
||||||
meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf'),
|
meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf'),
|
||||||
meta_parameters=get_params(optimc=args.optimc),
|
meta_parameters=get_params(optimc=args.optimc))
|
||||||
n_jobs=args.n_jobs)
|
|
||||||
|
|
||||||
# Init Funnelling Architecture
|
# Init Funnelling Architecture
|
||||||
gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta, n_jobs=args.n_jobs)
|
gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta)
|
||||||
|
|
||||||
# Training ---------------------------------------
|
# Training ---------------------------------------
|
||||||
print('\n[Training Generalized Funnelling]')
|
print('\n[Training Generalized Funnelling]')
|
||||||
|
|
@ -71,7 +71,7 @@ def main(args):
|
||||||
print('\n[Testing Generalized Funnelling]')
|
print('\n[Testing Generalized Funnelling]')
|
||||||
time_te = time.time()
|
time_te = time.time()
|
||||||
ly_ = gfun.predict(lXte)
|
ly_ = gfun.predict(lXte)
|
||||||
l_eval = evaluate(ly_true=lyte, ly_pred=ly_, n_jobs=args.n_jobs)
|
l_eval = evaluate(ly_true=lyte, ly_pred=ly_)
|
||||||
time_te = round(time.time() - time_te, 3)
|
time_te = round(time.time() - time_te, 3)
|
||||||
print(f'Testing completed in {time_te} seconds!')
|
print(f'Testing completed in {time_te} seconds!')
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,6 @@ This module contains the view generators that take care of computing the view sp
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
# from time import time
|
# from time import time
|
||||||
|
|
||||||
import torch
|
|
||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import Trainer
|
||||||
from pytorch_lightning.loggers import TensorBoardLogger
|
from pytorch_lightning.loggers import TensorBoardLogger
|
||||||
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
||||||
|
|
@ -242,10 +241,6 @@ class RecurrentGen(ViewGen):
|
||||||
self.logger = TensorBoardLogger(save_dir='../tb_logs', name='rnn', default_hp_metric=False)
|
self.logger = TensorBoardLogger(save_dir='../tb_logs', name='rnn', default_hp_metric=False)
|
||||||
self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
|
self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
|
||||||
patience=self.patience, verbose=False, mode='max')
|
patience=self.patience, verbose=False, mode='max')
|
||||||
|
|
||||||
# modifying EarlyStopping global var in order to compute >= with respect to the best score
|
|
||||||
self.early_stop_callback.mode_dict['max'] = torch.ge
|
|
||||||
|
|
||||||
self.lr_monitor = LearningRateMonitor(logging_interval='epoch')
|
self.lr_monitor = LearningRateMonitor(logging_interval='epoch')
|
||||||
|
|
||||||
def _init_model(self):
|
def _init_model(self):
|
||||||
|
|
@ -353,9 +348,6 @@ class BertGen(ViewGen):
|
||||||
self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
|
self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
|
||||||
patience=self.patience, verbose=False, mode='max')
|
patience=self.patience, verbose=False, mode='max')
|
||||||
|
|
||||||
# modifying EarlyStopping global var in order to compute >= with respect to the best score
|
|
||||||
self.early_stop_callback.mode_dict['max'] = torch.ge
|
|
||||||
|
|
||||||
def _init_model(self):
|
def _init_model(self):
|
||||||
output_size = self.multilingualIndex.get_target_dim()
|
output_size = self.multilingualIndex.get_target_dim()
|
||||||
return BertModel(output_size=output_size, stored_path=self.stored_path, gpus=self.gpus)
|
return BertModel(output_size=output_size, stored_path=self.stored_path, gpus=self.gpus)
|
||||||
|
|
@ -369,7 +361,7 @@ class BertGen(ViewGen):
|
||||||
:param ly: dict {lang: target vectors}
|
:param ly: dict {lang: target vectors}
|
||||||
:return: self.
|
:return: self.
|
||||||
"""
|
"""
|
||||||
print('# Fitting BertGen (B)...')
|
print('# Fitting BertGen (M)...')
|
||||||
create_if_not_exist(self.logger.save_dir)
|
create_if_not_exist(self.logger.save_dir)
|
||||||
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
|
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
|
||||||
bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)
|
bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue