Add training and evaluation logging via wandb

This commit is contained in:
andreapdr 2023-03-07 14:20:56 +01:00
parent f274ec7615
commit 7dead90271
5 changed files with 67 additions and 10 deletions

3
.gitignore vendored
View File

@ -181,4 +181,5 @@ models/*
scripts/
logger/*
explore_data.ipynb
run.sh
run.sh
wandb

View File

@ -47,7 +47,7 @@ class GeneralizedFunnelling:
self.posteriors_vgf = posterior
self.wce_vgf = wce
self.multilingual_vgf = multilingual
self.trasformer_vgf = textual_transformer
self.textual_trasformer_vgf = textual_transformer
self.visual_transformer_vgf = visual_transformer
self.probabilistic = probabilistic
self.num_labels = num_labels
@ -142,7 +142,7 @@ class GeneralizedFunnelling:
wce_vgf = WceGen(n_jobs=self.n_jobs)
self.first_tier_learners.append(wce_vgf)
if self.trasformer_vgf:
if self.textual_trasformer_vgf:
transformer_vgf = TextualTransformerGen(
dataset_name=self.dataset_name,
model_name=self.textaul_transformer_name,
@ -198,7 +198,8 @@ class GeneralizedFunnelling:
self.posteriors_vgf,
self.multilingual_vgf,
self.wce_vgf,
self.trasformer_vgf,
self.textual_trasformer_vgf,
self.visual_transformer_vgf,
self.aggfunc,
)
print(f"- model id: {self._model_id}")
@ -372,7 +373,7 @@ class GeneralizedFunnelling:
"rb",
) as vgf:
first_tier_learners.append(pickle.load(vgf))
if self.trasformer_vgf:
if self.textual_trasformer_vgf:
with open(
os.path.join(
"models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl"
@ -427,7 +428,15 @@ def get_params(optimc=False):
return [{"kernel": [kernel], "C": c_range, "gamma": ["auto"]}]
def get_unique_id(dataset_name, posterior, multilingual, wce, transformer, aggfunc):
def get_unique_id(
dataset_name,
posterior,
multilingual,
wce,
textual_transformer,
visual_transformer,
aggfunc,
):
from datetime import datetime
now = datetime.now().strftime("%y%m%d")
@ -435,6 +444,7 @@ def get_unique_id(dataset_name, posterior, multilingual, wce, transformer, aggfu
model_id += "p" if posterior else ""
model_id += "m" if multilingual else ""
model_id += "w" if wce else ""
model_id += "t" if transformer else ""
model_id += "t" if textual_transformer else ""
model_id += "v" if visual_transformer else ""
model_id += f"_{aggfunc}"
return f"{model_id}_{now}"

View File

@ -12,6 +12,7 @@ from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers.modeling_outputs import ModelOutput
import wandb
from evaluation.evaluate import evaluate, log_eval
PRINT_ON_EPOCH = 1
@ -114,6 +115,7 @@ class Trainer:
patience,
experiment_name,
checkpoint_path,
vgf_name,
):
self.device = device
self.model = model.to(device)
@ -130,6 +132,7 @@ class Trainer:
verbose=False,
experiment_name=experiment_name,
)
self.vgf_name = vgf_name
def init_optimizer(self, optimizer_name, lr):
if optimizer_name.lower() == "adamw":
@ -138,6 +141,25 @@ class Trainer:
raise ValueError(f"Optimizer {optimizer_name} not supported")
def train(self, train_dataloader, eval_dataloader, epochs=10):
wandb.init(
project="gfun",
name="allhere",
# reinit=True,
config={
"vgf": self.vgf_name,
"architecture": self.model.name_or_path,
"learning_rate": self.optimizer.defaults["lr"],
"epochs": epochs,
"train batch size": train_dataloader.batch_size,
"eval batch size": eval_dataloader.batch_size,
"max len": train_dataloader.dataset.X.shape[-1],
"patience": self.earlystopping.patience,
"evaluate every": self.evaluate_steps,
"print eval every": self.print_eval,
"print train steps": self.print_steps,
},
)
print(
f"""- Training params for {self.experiment_name}:
- epochs: {epochs}
@ -150,11 +172,14 @@ class Trainer:
- print eval every: {self.print_eval}
- print train steps: {self.print_steps}\n"""
)
for epoch in range(epochs):
self.train_epoch(train_dataloader, epoch)
if (epoch + 1) % self.evaluate_steps == 0:
print_eval = (epoch + 1) % self.print_eval == 0
metric_watcher = self.evaluate(eval_dataloader, print_eval=print_eval)
metric_watcher = self.evaluate(
eval_dataloader, epoch, print_eval=print_eval
)
stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
if stop:
print(
@ -183,9 +208,16 @@ class Trainer:
if (epoch + 1) % PRINT_ON_EPOCH == 0:
if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
wandb.log(
{
f"{wandb.config['vgf']}_training_loss": loss,
# "epoch": epoch,
# f"{wandb.config['vgf']}_epoch": epoch,
}
)
return self
def evaluate(self, dataloader, print_eval=True):
def evaluate(self, dataloader, epoch, print_eval=True):
self.model.eval()
lY = defaultdict(list)
@ -210,6 +242,14 @@ class Trainer:
l_eval = evaluate(lY, lY_hat)
average_metrics = log_eval(l_eval, phase="validation", verbose=print_eval)
wandb.log(
{
f"{wandb.config['vgf']}_eval_metric": average_metrics[0],
f"{wandb.config['vgf']}_eval_loss": loss,
# "epoch": epoch,
# f"{wandb.config['vgf']}_epoch": epoch,
}
)
return average_metrics[0] # macro-F1

View File

@ -130,6 +130,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
experiment_name = (
f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
)
trainer = Trainer(
model=self.model,
optimizer_name="adamW",
@ -141,6 +142,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
patience=self.patience,
experiment_name=experiment_name,
checkpoint_path="models/vgfs/transformer",
vgf_name="textual_trf",
)
trainer.train(
train_dataloader=tra_dataloader,

View File

@ -97,7 +97,10 @@ class VisualTransformerGen(ViewGen, TransformerGen):
shuffle=False,
)
experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"
experiment_name = (
f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
)
trainer = Trainer(
model=self.model,
optimizer_name="adamW",
@ -109,6 +112,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
patience=self.patience,
experiment_name=experiment_name,
checkpoint_path="models/vgfs/transformer",
vgf_name="visual_trf",
)
trainer.train(