logging via wandb

parent 6b7917ca47
commit 84dd1f093e
@@ -28,6 +28,7 @@ class GeneralizedFunnelling:
         embed_dir,
         n_jobs,
         batch_size,
+        eval_batch_size,
         max_length,
         lr,
         epochs,
@@ -59,7 +60,8 @@ class GeneralizedFunnelling:
         self.textual_trf_name = textual_transformer_name
         self.epochs = epochs
         self.lr_transformer = lr
-        self.batch_size_transformer = batch_size
+        self.batch_size_trf = batch_size
+        self.eval_batch_size_trf = eval_batch_size
         self.max_length = max_length
         self.early_stopping = True
         self.patience = patience
@@ -148,7 +150,8 @@ class GeneralizedFunnelling:
                 model_name=self.textual_trf_name,
                 lr=self.lr_transformer,
                 epochs=self.epochs,
-                batch_size=self.batch_size_transformer,
+                batch_size=self.batch_size_trf,
+                batch_size_eval=self.eval_batch_size_trf,
                 max_length=self.max_length,
                 print_steps=50,
                 probabilistic=self.probabilistic,
@@ -163,10 +166,10 @@ class GeneralizedFunnelling:
             visual_trasformer_vgf = VisualTransformerGen(
                 dataset_name=self.dataset_name,
                 model_name="vit",
-                lr=1e-5,  # self.lr_visual_transformer,
+                lr=self.lr_transformer,
                 epochs=self.epochs,
-                batch_size=32,  # self.batch_size_visual_transformer,
-                # batch_size_eval=128,
+                batch_size=self.batch_size_trf,
+                batch_size_eval=self.eval_batch_size_trf,
                 probabilistic=self.probabilistic,
                 evaluate_step=self.evaluate_step,
                 patience=self.patience,
@@ -140,46 +140,50 @@ class Trainer:
         else:
             raise ValueError(f"Optimizer {optimizer_name} not supported")

-    def train(self, train_dataloader, eval_dataloader, epochs=10):
-        wandb.init(
-            project="gfun",
-            name="allhere",
-            # reinit=True,
-            config={
-                "vgf": self.vgf_name,
-                "architecture": self.model.name_or_path,
-                "learning_rate": self.optimizer.defaults["lr"],
-                "epochs": epochs,
-                "train batch size": train_dataloader.batch_size,
-                "eval batch size": eval_dataloader.batch_size,
-                "max len": train_dataloader.dataset.X.shape[-1],
-                "patience": self.earlystopping.patience,
-                "evaluate every": self.evaluate_steps,
-                "print eval every": self.print_eval,
-                "print train steps": self.print_steps,
-            },
-        )
-
-        print(
-            f"""- Training params for {self.experiment_name}:
-            - epochs: {epochs}
-            - learning rate: {self.optimizer.defaults['lr']}
-            - train batch size: {train_dataloader.batch_size}
-            - eval batch size: {eval_dataloader.batch_size}
-            - max len: {train_dataloader.dataset.X.shape[-1]}
-            - patience: {self.earlystopping.patience}
-            - evaluate every: {self.evaluate_steps}
-            - print eval every: {self.print_eval}
-            - print train steps: {self.print_steps}\n"""
-        )
+    def get_config(self, train_dataloader, eval_dataloader, epochs):
+        return {
+            "model name": self.model.name_or_path,
+            "epochs": epochs,
+            "learning rate": self.optimizer.defaults["lr"],
+            "train batch size": train_dataloader.batch_size,
+            "eval batch size": eval_dataloader.batch_size,
+            "max len": train_dataloader.dataset.X.shape[-1],
+            "patience": self.earlystopping.patience,
+            "evaluate every": self.evaluate_steps,
+            "print eval every": self.print_eval,
+            "print train steps": self.print_steps,
+        }
+
+    def train(self, train_dataloader, eval_dataloader, epochs=10):
+        _config = self.get_config(train_dataloader, eval_dataloader, epochs)
+
+        print(f"- Training params for {self.experiment_name}:")
+        for k, v in _config.items():
+            print(f"\t{k}: {v}")
+
+        wandb_logger = wandb.init(
+            project="gfun", entity="andreapdr", config=_config, reinit=True
+        )

         for epoch in range(epochs):
-            self.train_epoch(train_dataloader, epoch)
+            train_loss = self.train_epoch(train_dataloader, epoch)
+
+            wandb_logger.log({f"{self.vgf_name}_train_loss": train_loss})

             if (epoch + 1) % self.evaluate_steps == 0:
                 print_eval = (epoch + 1) % self.print_eval == 0
-                metric_watcher = self.evaluate(
-                    eval_dataloader, epoch, print_eval=print_eval
-                )
+                with torch.no_grad():
+                    eval_loss, metric_watcher = self.evaluate(
+                        eval_dataloader, epoch, print_eval=print_eval
+                    )
+
+                wandb_logger.log(
+                    {
+                        f"{self.vgf_name}_eval_loss": eval_loss,
+                        f"{self.vgf_name}_eval_metric": metric_watcher,
+                    }
+                )

                 stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
                 if stop:
                     print(
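Aside: the rewritten train() above uses the standard wandb run API, in which wandb.init() returns a run handle whose log() method records a dict of scalar metrics per call. A minimal self-contained sketch of that pattern follows; the project name, config values, and loss values are placeholders, not the repo's.

    import wandb

    # Minimal sketch of the run-based logging pattern (placeholder values).
    run = wandb.init(project="demo", config={"epochs": 3, "lr": 1e-5})

    for epoch in range(run.config["epochs"]):
        train_loss = 1.0 / (epoch + 1)       # stand-in for a real training epoch
        run.log({"train_loss": train_loss})  # one point per epoch on the dashboard

    run.finish()  # flush buffered metrics and close the run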
@@ -189,8 +193,9 @@ class Trainer:
                         self.device
                     )
                     break

         print(f"- last swipe on eval set")
-        self.train_epoch(eval_dataloader, epoch=0)
+        self.train_epoch(eval_dataloader, epoch=-1)
         self.earlystopping.save_model(self.model)
         return self.model
@@ -208,14 +213,7 @@ class Trainer:
             if (epoch + 1) % PRINT_ON_EPOCH == 0:
                 if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
                     print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
-                    wandb.log(
-                        {
-                            f"{wandb.config['vgf']}_training_loss": loss,
-                            # "epoch": epoch,
-                            # f"{wandb.config['vgf']}_epoch": epoch,
-                        }
-                    )
-        return self
+        return loss.item()

     def evaluate(self, dataloader, epoch, print_eval=True):
         self.model.eval()
@@ -242,15 +240,8 @@ class Trainer:

         l_eval = evaluate(lY, lY_hat)
         average_metrics = log_eval(l_eval, phase="validation", verbose=print_eval)
-        wandb.log(
-            {
-                f"{wandb.config['vgf']}_eval_metric": average_metrics[0],
-                f"{wandb.config['vgf']}_eval_loss": loss,
-                # "epoch": epoch,
-                # f"{wandb.config['vgf']}_epoch": epoch,
-            }
-        )
-        return average_metrics[0]  # macro-F1
+        return loss.item(), average_metrics[0]  # macro-F1


 class EarlyStopping:
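Aside: the new train() wraps the call to evaluate() in torch.no_grad(), which disables autograd bookkeeping so the validation forward passes build no graph and use less memory. A standalone illustration with a placeholder model and data:

    import torch

    model = torch.nn.Linear(4, 2)
    x = torch.randn(8, 4)

    model.eval()                   # switch layers like dropout to eval mode
    with torch.no_grad():          # no autograd graph is built in this block
        logits = model(x)
        loss = torch.nn.functional.cross_entropy(
            logits, torch.zeros(8, dtype=torch.long)
        )
    print(loss.item())             # a plain float, as evaluate() now returns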
main.py (2 changes)
@@ -54,6 +54,7 @@ def main(args):
        textual_transformer=args.textual_transformer,
        textual_transformer_name=args.transformer_name,
        batch_size=args.batch_size,
+       eval_batch_size=args.eval_batch_size,
        epochs=args.epochs,
        lr=args.lr,
        max_length=args.max_length,
@@ -125,6 +126,7 @@ if __name__ == "__main__":
     # transformer parameters ---------------
     parser.add_argument("--transformer_name", type=str, default="mbert")
     parser.add_argument("--batch_size", type=int, default=32)
+    parser.add_argument("--eval_batch_size", type=int, default=128)
     parser.add_argument("--epochs", type=int, default=100)
     parser.add_argument("--lr", type=float, default=1e-5)
     parser.add_argument("--max_length", type=int, default=128)
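Aside: the new --eval_batch_size flag can be sanity-checked in isolation with a minimal parser mirroring the two lines above (a standalone snippet, not code from the repo):

    import argparse

    # Standalone parser mirroring the two batch-size flags above.
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--eval_batch_size", type=int, default=128)

    args = parser.parse_args(["--eval_batch_size", "256"])
    print(args.batch_size, args.eval_batch_size)  # -> 32 256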