diff --git a/gfun/vgfs/commons.py b/gfun/vgfs/commons.py index e6527cd..1787c0b 100644 --- a/gfun/vgfs/commons.py +++ b/gfun/vgfs/commons.py @@ -188,7 +188,9 @@ class Trainer: def get_config(self, train_dataloader, eval_dataloader, epochs): return { - "model name": self.model.name_or_path, + "model name": self.model.name_or_path + if not hasattr(self.model, "mt5encoder") + else self.model.mt5encoder.name_or_path, "epochs": epochs, "learning rate": self.optimizer.defaults["lr"], "scheduler": self.scheduler_name, # TODO: add scheduler params @@ -212,7 +214,11 @@ class Trainer: print(f"\t{k}: {v}") wandb_logger = wandb.init( - project="gfun", entity="andreapdr", config=_config, reinit=True + project="gfun", + entity="andreapdr", + name=f"{_config['model name']} lr: {_config['learning rate']} scheduler: {_config['scheduler']}", + config=_config, + reinit=True, ) for epoch in range(epochs):