update trainer to handle mT5

This commit is contained in:
Andrea Pedrotti 2023-03-15 11:47:17 +01:00
parent 26aa0b327a
commit 65407f51fa
1 changed file with 8 additions and 2 deletions

View File

@@ -188,7 +188,9 @@ class Trainer:
def get_config(self, train_dataloader, eval_dataloader, epochs):
return {
"model name": self.model.name_or_path,
"model name": self.model.name_or_path
if not hasattr(self.model, "mt5encoder")
else self.model.mt5encoder.name_or_path,
"epochs": epochs,
"learning rate": self.optimizer.defaults["lr"],
"scheduler": self.scheduler_name, # TODO: add scheduler params
@@ -212,7 +214,11 @@ class Trainer:
print(f"\t{k}: {v}")
wandb_logger = wandb.init(
project="gfun", entity="andreapdr", config=_config, reinit=True
project="gfun",
entity="andreapdr",
name=f"{_config['model name']} lr: {_config['learning rate']} scheduler: {_config['scheduler']}",
config=_config,
reinit=True,
)
for epoch in range(epochs):