update trainer to handle mT5
This commit is contained in:
parent
26aa0b327a
commit
65407f51fa
|
@ -188,7 +188,9 @@ class Trainer:
|
|||
|
||||
def get_config(self, train_dataloader, eval_dataloader, epochs):
|
||||
return {
|
||||
"model name": self.model.name_or_path,
|
||||
"model name": self.model.name_or_path
|
||||
if not hasattr(self.model, "mt5encoder")
|
||||
else self.model.mt5encoder.name_or_path,
|
||||
"epochs": epochs,
|
||||
"learning rate": self.optimizer.defaults["lr"],
|
||||
"scheduler": self.scheduler_name, # TODO: add scheduler params
|
||||
|
@ -212,7 +214,11 @@ class Trainer:
|
|||
print(f"\t{k}: {v}")
|
||||
|
||||
wandb_logger = wandb.init(
|
||||
project="gfun", entity="andreapdr", config=_config, reinit=True
|
||||
project="gfun",
|
||||
entity="andreapdr",
|
||||
name=f"{_config['model name']} lr: {_config['learning rate']} scheduler: {_config['scheduler']}",
|
||||
config=_config,
|
||||
reinit=True,
|
||||
)
|
||||
|
||||
for epoch in range(epochs):
|
||||
|
|
Loading…
Reference in New Issue