Compare commits

...

10 Commits

Author SHA1 Message Date
andrea a4f74dcf41 fixed pl early stop --> patience was consumed if actual_monitor == best_monitor. Set policy to greater or equal. 2021-02-11 18:37:34 +01:00
andrea a1c4247e17 fixed common after problematic merge 2021-02-05 11:22:30 +01:00
andrea 1ac850630b Merge branch 'devel' 2021-02-05 11:07:40 +01:00
andrea 59146f0dda fixed typos + n_jobs across code (still missing one wrt branch 'rsc') 2021-02-04 16:52:05 +01:00
andrea ec050dce7b typos 2021-02-03 12:30:44 +01:00
andrea e78b1f8a30 merged devel 2021-02-03 11:20:08 +01:00
andrea b98821d3ff running comparison with refactor branch 2021-01-29 14:56:20 +01:00
andrea 5405f60bd0 running comparison with refactor branch 2021-01-29 14:50:34 +01:00
andrea 66952820f9 running comparison with refactor branch 2021-01-29 12:30:31 +01:00
andrea 091101b39d running comparison with refactor branch 2021-01-29 11:37:42 +01:00
2 changed files with 13 additions and 5 deletions

View File

@@ -49,16 +49,16 @@ def main(args):
if args.bert_embedder:
bertEmbedder = BertGen(multilingualIndex, batch_size=args.batch_bert, nepochs=args.nepochs_bert,
patience=args.patience_bert, gpus=args.gpus, n_jobs=args.n_jobs)
bertEmbedder.transform(lX)
embedder_list.append(bertEmbedder)
# Init DocEmbedderList (i.e., first-tier learners or view generators) and metaclassifier
docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True)
meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf'),
meta_parameters=get_params(optimc=args.optimc))
meta_parameters=get_params(optimc=args.optimc),
n_jobs=args.n_jobs)
# Init Funnelling Architecture
gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta)
gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta, n_jobs=args.n_jobs)
# Training ---------------------------------------
print('\n[Training Generalized Funnelling]')
@@ -71,7 +71,7 @@ def main(args):
print('\n[Testing Generalized Funnelling]')
time_te = time.time()
ly_ = gfun.predict(lXte)
l_eval = evaluate(ly_true=lyte, ly_pred=ly_)
l_eval = evaluate(ly_true=lyte, ly_pred=ly_, n_jobs=args.n_jobs)
time_te = round(time.time() - time_te, 3)
print(f'Testing completed in {time_te} seconds!')

View File

@@ -18,6 +18,7 @@ This module contains the view generators that take care of computing the view sp
from abc import ABC, abstractmethod
# from time import time
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
@@ -241,6 +242,10 @@ class RecurrentGen(ViewGen):
self.logger = TensorBoardLogger(save_dir='../tb_logs', name='rnn', default_hp_metric=False)
self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
patience=self.patience, verbose=False, mode='max')
# modifying EarlyStopping global var in order to compute >= with respect to the best score
self.early_stop_callback.mode_dict['max'] = torch.ge
self.lr_monitor = LearningRateMonitor(logging_interval='epoch')
def _init_model(self):
@@ -348,6 +353,9 @@ class BertGen(ViewGen):
self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
patience=self.patience, verbose=False, mode='max')
# modifying EarlyStopping global var in order to compute >= with respect to the best score
self.early_stop_callback.mode_dict['max'] = torch.ge
def _init_model(self):
output_size = self.multilingualIndex.get_target_dim()
return BertModel(output_size=output_size, stored_path=self.stored_path, gpus=self.gpus)
@@ -361,7 +369,7 @@ class BertGen(ViewGen):
:param ly: dict {lang: target vectors}
:return: self.
"""
print('# Fitting BertGen (M)...')
print('# Fitting BertGen (B)...')
create_if_not_exist(self.logger.save_dir)
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)