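Update Trainer to handle mT5: report the wrapped mt5encoder's name_or_path in the run config and give each wandb run a descriptive name (model name, learning rate, scheduler)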
This commit is contained in:
parent
26aa0b327a
commit
65407f51fa
|
@ -188,7 +188,9 @@ class Trainer:
|
||||||
|
|
||||||
def get_config(self, train_dataloader, eval_dataloader, epochs):
|
def get_config(self, train_dataloader, eval_dataloader, epochs):
|
||||||
return {
|
return {
|
||||||
"model name": self.model.name_or_path,
|
"model name": self.model.name_or_path
|
||||||
|
if not hasattr(self.model, "mt5encoder")
|
||||||
|
else self.model.mt5encoder.name_or_path,
|
||||||
"epochs": epochs,
|
"epochs": epochs,
|
||||||
"learning rate": self.optimizer.defaults["lr"],
|
"learning rate": self.optimizer.defaults["lr"],
|
||||||
"scheduler": self.scheduler_name, # TODO: add scheduler params
|
"scheduler": self.scheduler_name, # TODO: add scheduler params
|
||||||
|
@ -212,7 +214,11 @@ class Trainer:
|
||||||
print(f"\t{k}: {v}")
|
print(f"\t{k}: {v}")
|
||||||
|
|
||||||
wandb_logger = wandb.init(
|
wandb_logger = wandb.init(
|
||||||
project="gfun", entity="andreapdr", config=_config, reinit=True
|
project="gfun",
|
||||||
|
entity="andreapdr",
|
||||||
|
name=f"{_config['model name']} lr: {_config['learning rate']} scheduler: {_config['scheduler']}",
|
||||||
|
config=_config,
|
||||||
|
reinit=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
|
|
Loading…
Reference in New Issue