logging via wandb

Andrea Pedrotti 2023-03-07 17:34:25 +01:00
parent 6b7917ca47
commit 84dd1f093e
3 changed files with 52 additions and 56 deletions
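
The change itself: instead of a hard-wired wandb.init inside Trainer.train, each trainer now builds a config dict and opens its own run. Because a single gFun fit can train several view-generating functions (a textual and a visual transformer) in the same process, reinit=True gives each one a separate run, and every metric key is prefixed with the VGF name. A minimal, hedged reminder of the wandb calls this relies on (key prefix and values below are illustrative, not taken from the repo):

import wandb

run = wandb.init(project="gfun", config={"epochs": 10}, reinit=True)  # one run per VGF trainer
run.log({"textualTransformer_train_loss": 0.12})  # dummy value; logged once per epoch in the real loop
run.finish()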


@@ -28,6 +28,7 @@ class GeneralizedFunnelling:
embed_dir,
n_jobs,
batch_size,
eval_batch_size,
max_length,
lr,
epochs,
@@ -59,7 +60,8 @@ class GeneralizedFunnelling:
self.textual_trf_name = textual_transformer_name
self.epochs = epochs
self.lr_transformer = lr
self.batch_size_transformer = batch_size
self.batch_size_trf = batch_size
self.eval_batch_size_trf = eval_batch_size
self.max_length = max_length
self.early_stopping = True
self.patience = patience
@@ -148,7 +150,8 @@ class GeneralizedFunnelling:
model_name=self.textual_trf_name,
lr=self.lr_transformer,
epochs=self.epochs,
batch_size=self.batch_size_transformer,
batch_size=self.batch_size_trf,
batch_size_eval=self.eval_batch_size_trf,
max_length=self.max_length,
print_steps=50,
probabilistic=self.probabilistic,
@@ -163,10 +166,10 @@ class GeneralizedFunnelling:
visual_trasformer_vgf = VisualTransformerGen(
dataset_name=self.dataset_name,
model_name="vit",
lr=1e-5, # self.lr_visual_transformer,
lr=self.lr_transformer,
epochs=self.epochs,
batch_size=32, # self.batch_size_visual_transformer,
# batch_size_eval=128,
batch_size=self.batch_size_trf,
batch_size_eval=self.eval_batch_size_trf,
probabilistic=self.probabilistic,
evaluate_step=self.evaluate_step,
patience=self.patience,
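
The first file threads the new eval_batch_size from the GeneralizedFunnelling constructor into both the textual and the visual transformer VGFs; previously the visual one hard-coded batch_size=32 and left batch_size_eval commented out, while both now share self.batch_size_trf and self.eval_batch_size_trf. A larger evaluation batch is generally safe because evaluation runs without gradients; a hedged torch sketch using the new defaults (32 for training, 128 for evaluation, toy data only):

import torch
from torch.utils.data import DataLoader, TensorDataset

toy = TensorDataset(torch.randn(1024, 16), torch.randint(0, 2, (1024,)))  # illustrative data only
train_dl = DataLoader(toy, batch_size=32, shuffle=True)    # --batch_size
eval_dl = DataLoader(toy, batch_size=128, shuffle=False)   # --eval_batch_size, consumed under torch.no_grad()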


@@ -140,46 +140,50 @@ class Trainer:
else:
raise ValueError(f"Optimizer {optimizer_name} not supported")
def train(self, train_dataloader, eval_dataloader, epochs=10):
wandb.init(
project="gfun",
name="allhere",
# reinit=True,
config={
"vgf": self.vgf_name,
"architecture": self.model.name_or_path,
"learning_rate": self.optimizer.defaults["lr"],
"epochs": epochs,
"train batch size": train_dataloader.batch_size,
"eval batch size": eval_dataloader.batch_size,
"max len": train_dataloader.dataset.X.shape[-1],
"patience": self.earlystopping.patience,
"evaluate every": self.evaluate_steps,
"print eval every": self.print_eval,
"print train steps": self.print_steps,
},
)
def get_config(self, train_dataloader, eval_dataloader, epochs):
return {
"model name": self.model.name_or_path,
"epochs": epochs,
"learning rate": self.optimizer.defaults["lr"],
"train batch size": train_dataloader.batch_size,
"eval batch size": eval_dataloader.batch_size,
"max len": train_dataloader.dataset.X.shape[-1],
"patience": self.earlystopping.patience,
"evaluate every": self.evaluate_steps,
"print eval every": self.print_eval,
"print train steps": self.print_steps,
}
print(
f"""- Training params for {self.experiment_name}:
- epochs: {epochs}
- learning rate: {self.optimizer.defaults['lr']}
- train batch size: {train_dataloader.batch_size}
- eval batch size: {eval_dataloader.batch_size}
- max len: {train_dataloader.dataset.X.shape[-1]}
- patience: {self.earlystopping.patience}
- evaluate every: {self.evaluate_steps}
- print eval every: {self.print_eval}
- print train steps: {self.print_steps}\n"""
def train(self, train_dataloader, eval_dataloader, epochs=10):
_config = self.get_config(train_dataloader, eval_dataloader, epochs)
print(f"- Training params for {self.experiment_name}:")
for k, v in _config.items():
print(f"\t{k}: {v}")
wandb_logger = wandb.init(
project="gfun", entity="andreapdr", config=_config, reinit=True
)
for epoch in range(epochs):
self.train_epoch(train_dataloader, epoch)
train_loss = self.train_epoch(train_dataloader, epoch)
wandb_logger.log({f"{self.vgf_name}_train_loss": train_loss})
if (epoch + 1) % self.evaluate_steps == 0:
print_eval = (epoch + 1) % self.print_eval == 0
metric_watcher = self.evaluate(
eval_dataloader, epoch, print_eval=print_eval
with torch.no_grad():
eval_loss, metric_watcher = self.evaluate(
eval_dataloader, epoch, print_eval=print_eval
)
wandb_logger.log(
{
f"{self.vgf_name}_eval_loss": eval_loss,
f"{self.vgf_name}_eval_metric": metric_watcher,
}
)
stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
if stop:
print(
@@ -189,8 +193,9 @@ class Trainer:
self.device
)
break
print(f"- last swipe on eval set")
self.train_epoch(eval_dataloader, epoch=0)
self.train_epoch(eval_dataloader, epoch=-1)
self.earlystopping.save_model(self.model)
return self.model
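
Read as a whole, the refactored loop keeps every wandb call inside train(): train_epoch and evaluate now return plain numbers (loss.item(), and the pair eval_loss / macro-F1), evaluation runs under torch.no_grad(), and the run handle returned by wandb.init does the logging. A self-contained miniature of that control flow, with stubbed helpers and wandb in disabled mode so the sketch runs offline (drop mode="disabled" to log for real; names and values are illustrative):

import torch
import wandb

def train_epoch_stub(epoch):               # stand-in for Trainer.train_epoch, returns loss.item()
    return 1.0 / (epoch + 1)

def evaluate_stub(epoch):                  # stand-in for Trainer.evaluate -> (eval_loss, macro_f1)
    return 0.5 / (epoch + 1), 0.6 + 0.01 * epoch

vgf_name, evaluate_steps, epochs = "textualTransformer", 2, 6   # illustrative values
run = wandb.init(project="gfun", mode="disabled", reinit=True, config={"epochs": epochs})
for epoch in range(epochs):
    train_loss = train_epoch_stub(epoch)
    run.log({f"{vgf_name}_train_loss": train_loss})
    if (epoch + 1) % evaluate_steps == 0:
        with torch.no_grad():
            eval_loss, macro_f1 = evaluate_stub(epoch)
        run.log({f"{vgf_name}_eval_loss": eval_loss, f"{vgf_name}_eval_metric": macro_f1})
run.finish()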
@@ -208,14 +213,7 @@ class Trainer:
if (epoch + 1) % PRINT_ON_EPOCH == 0:
if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
wandb.log(
{
f"{wandb.config['vgf']}_training_loss": loss,
# "epoch": epoch,
# f"{wandb.config['vgf']}_epoch": epoch,
}
)
return self
return loss.item()
def evaluate(self, dataloader, epoch, print_eval=True):
self.model.eval()
@@ -242,15 +240,8 @@ class Trainer:
l_eval = evaluate(lY, lY_hat)
average_metrics = log_eval(l_eval, phase="validation", verbose=print_eval)
wandb.log(
{
f"{wandb.config['vgf']}_eval_metric": average_metrics[0],
f"{wandb.config['vgf']}_eval_loss": loss,
# "epoch": epoch,
# f"{wandb.config['vgf']}_epoch": epoch,
}
)
return average_metrics[0] # macro-F1
return loss.item(), average_metrics[0] # macro-F1
class EarlyStopping:


@@ -54,6 +54,7 @@ def main(args):
textual_transformer=args.textual_transformer,
textual_transformer_name=args.transformer_name,
batch_size=args.batch_size,
eval_batch_size=args.eval_batch_size,
epochs=args.epochs,
lr=args.lr,
max_length=args.max_length,
@@ -125,6 +126,7 @@ if __name__ == "__main__":
# transformer parameters ---------------
parser.add_argument("--transformer_name", type=str, default="mbert")
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--eval_batch_size", type=int, default=128)
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--lr", type=float, default=1e-5)
parser.add_argument("--max_length", type=int, default=128)