improved wandb logging

andreapdr 2023-03-09 17:03:17 +01:00
parent 3240150542
commit 7e1ec46ebd
6 changed files with 215 additions and 90 deletions

View File

@@ -1,46 +1,96 @@
 from joblib import Parallel, delayed
+from collections import defaultdict

 from evaluation.metrics import *
+from sklearn.metrics import accuracy_score, top_k_accuracy_score, f1_score


-def evaluation_metrics(y, y_):
-    if len(y.shape) == len(y_.shape) == 1 and len(np.unique(y)) > 2:  # single-label
-        raise NotImplementedError()
-    else:
-        # TODO: we need the logits to compute this
-        # TODO: we need logits
+def evaluation_metrics(y, y_, clf_type):
+    if clf_type == "singlelabel":
+        return (
+            accuracy_score(y, y_),
+            top_k_accuracy_score(y, y_, k=5),
+            top_k_accuracy_score(y, y_, k=10),
+            f1_score(y, y_, average="macro", zero_division=1),
+            f1_score(y, y_, average="micro"),
+        )
+    elif clf_type == "multilabel":
         return (
             macroF1(y, y_),
             microF1(y, y_),
             macroK(y, y_),
             microK(y, y_),
-            # macroAcc(y, y_),
         )
+    else:
+        raise ValueError("clf_type must be either 'singlelabel' or 'multilabel'")


-def evaluate(ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1):
+def evaluate(
+    ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1, clf_type="multilabel"
+):
     if n_jobs == 1:
-        return {lang: metrics(ly_true[lang], ly_pred[lang]) for lang in ly_true.keys()}
+        return {
+            lang: metrics(ly_true[lang], ly_pred[lang], clf_type)
+            for lang in ly_true.keys()
+        }
     else:
         langs = list(ly_true.keys())
         evals = Parallel(n_jobs=n_jobs)(
-            delayed(metrics)(ly_true[lang], ly_pred[lang]) for lang in langs
+            delayed(metrics)(ly_true[lang], ly_pred[lang], clf_type) for lang in langs
         )
         return {lang: evals[i] for i, lang in enumerate(langs)}


-def log_eval(l_eval, phase="training", verbose=True):
+def log_eval(l_eval, phase="training", clf_type="multilabel", verbose=True):
     if verbose:
         print(f"\n[Results {phase}]")
     metrics = []
-    for lang in l_eval.keys():
-        macrof1, microf1, macrok, microk = l_eval[lang]
-        metrics.append([macrof1, microf1, macrok, microk])
-        if phase != "validation":
-            print(f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}")
-    averages = np.mean(np.array(metrics), axis=0)
-    if verbose:
-        print(
-            "Averages: MF1, mF1, MK, mK",
-            np.round(averages, 3),
-            "\n",
-        )
-    return averages
+
+    if clf_type == "multilabel":
+        for lang in l_eval.keys():
+            macrof1, microf1, macrok, microk = l_eval[lang]
+            metrics.append([macrof1, microf1, macrok, microk])
+            if phase != "validation":
+                print(f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}")
+        averages = np.mean(np.array(metrics), axis=0)
+        if verbose:
+            print(
+                "Averages: MF1, mF1, MK, mK",
+                np.round(averages, 3),
+                "\n",
+            )
+        return averages  # TODO: return a dict avg and lang specific
+    elif clf_type == "singlelabel":
+        lang_metrics = defaultdict(dict)
+        _metrics = [
+            "accuracy",
+            # "acc5",  # "accuracy-at-5",
+            # "acc10",  # "accuracy-at-10",
+            "MF1",  # "macro-F1",
+            "mF1",  # "micro-F1",
+        ]
+        for lang in l_eval.keys():
+            # acc, top5, top10, macrof1, microf1 = l_eval[lang]
+            acc, macrof1, microf1 = l_eval[lang]
+            # metrics.append([acc, top5, top10, macrof1, microf1])
+            metrics.append([acc, macrof1, microf1])
+            for m, v in zip(_metrics, l_eval[lang]):
+                lang_metrics[m][lang] = v
+            if phase != "validation":
+                print(
+                    # f"Lang {lang}: acc = {acc:.3f} acc-top5 = {top5:.3f} acc-top10 = {top10:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
+                    f"Lang {lang}: acc = {acc:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
+                )
+        averages = np.mean(np.array(metrics), axis=0)
+        if verbose:
+            print(
+                # "Averages: Acc, Acc-5, Acc-10, MF1, mF1",
+                "Averages: Acc, MF1, mF1",
+                np.round(averages, 3),
+                "\n",
+            )
+        avg_metrics = dict(zip(_metrics, averages))
+        return avg_metrics, lang_metrics
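Editor's aside, a minimal sketch (not part of the commit) of the per-language dictionary convention these functions share: keys are language codes, values are label matrices, and the result maps each language to its metric tuple. The language keys, one-hot label matrices, and metric choice here are made up for illustration.

import numpy as np
from sklearn.metrics import accuracy_score, f1_score

# One-hot label matrices per language (4 classes, 4 documents each).
ly_true = {"en": np.eye(4)[[0, 1, 2, 3]], "it": np.eye(4)[[1, 1, 0, 2]]}
ly_pred = {"en": np.eye(4)[[0, 1, 1, 3]], "it": np.eye(4)[[1, 0, 0, 2]]}

# Same {lang: metric_tuple} structure that evaluate() returns.
l_eval = {
    lang: (
        accuracy_score(ly_true[lang], ly_pred[lang]),
        f1_score(ly_true[lang], ly_pred[lang], average="macro", zero_division=1),
        f1_score(ly_true[lang], ly_pred[lang], average="micro"),
    )
    for lang in ly_true
}
print(l_eval)  # {'en': (0.75, ...), 'it': (0.75, ...)}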

gfun/generalizedFunnelling.py
View File

@@ -1,17 +1,14 @@
 import os
-import sys
-
-# sys.path.append(os.path.join(os.getcwd(), "gfun"))
 import pickle

 import numpy as np

-from gfun.vgfs.commons import TfidfVectorizerMultilingual, AttentionAggregator
+from gfun.vgfs.commons import AttentionAggregator, TfidfVectorizerMultilingual, predict
 from gfun.vgfs.learners.svms import MetaClassifier, get_learner
 from gfun.vgfs.multilingualGen import MultilingualGen
 from gfun.vgfs.textualTransformerGen import TextualTransformerGen
-from gfun.vgfs.visualTransformerGen import VisualTransformerGen
 from gfun.vgfs.vanillaFun import VanillaFunGen
+from gfun.vgfs.visualTransformerGen import VisualTransformerGen
 from gfun.vgfs.wceGen import WceGen
@@ -25,12 +22,14 @@ class GeneralizedFunnelling:
         visual_transformer,
         langs,
         num_labels,
+        classification_type,
         embed_dir,
         n_jobs,
         batch_size,
         eval_batch_size,
         max_length,
-        lr,
+        textual_lr,
+        visual_lr,
         epochs,
         patience,
         evaluate_step,
@@ -52,6 +51,7 @@ class GeneralizedFunnelling:
         self.visual_trf_vgf = visual_transformer
         self.probabilistic = probabilistic
         self.num_labels = num_labels
+        self.clf_type = classification_type
         # ------------------------
         self.langs = langs
         self.embed_dir = embed_dir
@@ -59,7 +59,8 @@ class GeneralizedFunnelling:
         # Textual Transformer VGF params ----------
         self.textual_trf_name = textual_transformer_name
         self.epochs = epochs
-        self.lr_transformer = lr
+        self.txt_trf_lr = textual_lr
+        self.vis_trf_lr = visual_lr
         self.batch_size_trf = batch_size
         self.eval_batch_size_trf = eval_batch_size
         self.max_length = max_length
@@ -114,7 +115,7 @@ class GeneralizedFunnelling:
             self.attn_aggregator = AttentionAggregator(
                 embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
                 out_dim=self.num_labels,
-                lr=self.lr_transformer,
+                lr=self.txt_trf_lr,
                 patience=self.patience,
                 num_heads=1,
                 device=self.device,
@@ -148,7 +149,7 @@ class GeneralizedFunnelling:
             transformer_vgf = TextualTransformerGen(
                 dataset_name=self.dataset_name,
                 model_name=self.textual_trf_name,
-                lr=self.lr_transformer,
+                lr=self.txt_trf_lr,
                 epochs=self.epochs,
                 batch_size=self.batch_size_trf,
                 batch_size_eval=self.eval_batch_size_trf,
@@ -159,6 +160,7 @@ class GeneralizedFunnelling:
                 verbose=True,
                 patience=self.patience,
                 device=self.device,
+                classification_type=self.clf_type,
             )
             self.first_tier_learners.append(transformer_vgf)
@@ -166,7 +168,7 @@ class GeneralizedFunnelling:
             visual_trasformer_vgf = VisualTransformerGen(
                 dataset_name=self.dataset_name,
                 model_name="vit",
-                lr=self.lr_transformer,
+                lr=self.vis_trf_lr,
                 epochs=self.epochs,
                 batch_size=self.batch_size_trf,
                 batch_size_eval=self.eval_batch_size_trf,
@@ -174,6 +176,7 @@ class GeneralizedFunnelling:
                 evaluate_step=self.evaluate_step,
                 patience=self.patience,
                 device=self.device,
+                classification_type=self.clf_type,
             )
             self.first_tier_learners.append(visual_trasformer_vgf)
@@ -182,7 +185,7 @@ class GeneralizedFunnelling:
             self.attn_aggregator = AttentionAggregator(
                 embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
                 out_dim=self.num_labels,
-                lr=self.lr_transformer,
+                lr=self.txt_trf_lr,
                 patience=self.patience,
                 num_heads=1,
                 device=self.device,
@@ -255,10 +258,9 @@ class GeneralizedFunnelling:
         projections.append(l_posteriors)
         agg = self.aggregate(projections)
         l_out = self.metaclassifier.predict_proba(agg)
-        # converting to binary predictions
-        # if self.dataset_name in ["cls"]:  # TODO: better way to do this
-        #     for lang, preds in l_out.items():
-        #         l_out[lang] = np.expand_dims(np.argmax(preds, axis=1), axis=1)
+        if self.clf_type == "singlelabel":
+            for lang, preds in l_out.items():
+                l_out[lang] = predict(preds, clf_type=self.clf_type)
         return l_out

     def fit_transform(self, lX, lY):

gfun/vgfs/commons.py
View File

@@ -9,6 +9,7 @@ from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import normalize
 from torch.optim import AdamW
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.utils.data import DataLoader, Dataset
 from transformers.modeling_outputs import ModelOutput
@@ -22,6 +23,21 @@ def _normalize(lX, l2=True):
     return {lang: normalize(np.asarray(X)) for lang, X in lX.items()} if l2 else lX


+def verbosity_eval(epoch, print_eval):
+    if (epoch + 1) % print_eval == 0 and epoch != 0:
+        return True
+    else:
+        return False
+
+
+def format_langkey_wandb(lang_dict):
+    log_dict = {}
+    for metric, l_dict in lang_dict.items():
+        for lang, value in l_dict.items():
+            log_dict[f"language metric/{metric}/{lang}"] = value
+    return log_dict
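A quick sketch (not in the commit) of what format_langkey_wandb produces: the nested {metric: {lang: value}} dict coming out of log_eval is flattened into slash-separated wandb keys, which the wandb UI groups into one panel section per metric. The dict comprehension below mirrors the function body; the metric values are made up.

lang_metrics = {"accuracy": {"en": 0.91, "it": 0.88}, "MF1": {"en": 0.74, "it": 0.70}}

log_dict = {
    f"language metric/{metric}/{lang}": value
    for metric, l_dict in lang_metrics.items()
    for lang, value in l_dict.items()
}
print(log_dict)
# {'language metric/accuracy/en': 0.91, 'language metric/accuracy/it': 0.88,
#  'language metric/MF1/en': 0.74, 'language metric/MF1/it': 0.70}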
 def XdotM(X, M, sif):
     E = X.dot(M)
     if sif:
@@ -58,18 +74,23 @@ def compute_pc(X, npc=1):
     return svd.components_


-def predict(logits, classification_type="multilabel"):
+def predict(logits, clf_type="multilabel"):
     """
     Converts soft predictions to hard predictions [0,1]
     """
-    if classification_type == "multilabel":
+    if clf_type == "multilabel":
         prediction = torch.sigmoid(logits) > 0.5
-    elif classification_type == "singlelabel":
-        prediction = torch.argmax(logits, dim=1).view(-1, 1)
+        return prediction.detach().cpu().numpy()
+    elif clf_type == "singlelabel":
+        if type(logits) != torch.Tensor:
+            logits = torch.tensor(logits)
+        prediction = torch.softmax(logits, dim=1)
+        prediction = prediction.detach().cpu().numpy()
+        _argmaxs = prediction.argmax(axis=1)
+        prediction = np.eye(prediction.shape[1])[_argmaxs]
+        return prediction
     else:
-        print("unknown classification type")
-    return prediction.detach().cpu().numpy()
+        raise NotImplementedError()
 class TfidfVectorizerMultilingual:
@@ -115,36 +136,54 @@ class Trainer:
         patience,
         experiment_name,
         checkpoint_path,
+        classification_type,
         vgf_name,
+        n_jobs,
+        scheduler_name=None,
     ):
         self.device = device
         self.model = model.to(device)
-        self.optimizer = self.init_optimizer(optimizer_name, lr)
+        self.optimizer, self.scheduler = self.init_optimizer(
+            optimizer_name, lr, scheduler_name
+        )
         self.evaluate_steps = evaluate_step
         self.loss_fn = loss_fn.to(device)
         self.print_steps = print_steps
         self.experiment_name = experiment_name
         self.patience = patience
-        self.print_eval = evaluate_step
+        self.print_eval = 10
         self.earlystopping = EarlyStopping(
             patience=patience,
             checkpoint_path=checkpoint_path,
             verbose=False,
             experiment_name=experiment_name,
         )
+        self.clf_type = classification_type
         self.vgf_name = vgf_name
+        self.n_jobs = n_jobs
+        self.monitored_metric = (
+            "macro-F1" if self.clf_type == "multilabel" else "accuracy"
+        )  # TODO: make this configurable

-    def init_optimizer(self, optimizer_name, lr):
+    def init_optimizer(self, optimizer_name, lr, scheduler_name):
         if optimizer_name.lower() == "adamw":
-            return AdamW(self.model.parameters(), lr=lr)
+            optim = AdamW(self.model.parameters(), lr=lr)
         else:
             raise ValueError(f"Optimizer {optimizer_name} not supported")
+        if scheduler_name is None:
+            scheduler = None
+        elif scheduler_name == "ReduceLROnPlateau":
+            scheduler = ReduceLROnPlateau(optim, "max", factor=0.5, min_lr=1e-5)
+        else:
+            raise ValueError(f"Scheduler {scheduler_name} not supported")
+        return optim, scheduler
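A minimal, self-contained sketch of the ReduceLROnPlateau setup used above. Mode "max" is used because the monitored value is a score (macro-F1 or accuracy) rather than a loss, so the learning rate is halved once the score stops improving. The toy model, the larger initial lr, and the short patience are assumptions chosen so the reduction is visible; the commit itself uses the optimizer's lr with patience left at the PyTorch default.

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = torch.nn.Linear(4, 2)  # toy model (assumption)
optim = AdamW(model.parameters(), lr=1e-3)
# mode="max": the monitored value is a score, not a loss
scheduler = ReduceLROnPlateau(optim, "max", factor=0.5, patience=2, min_lr=1e-5)

for epoch, score in enumerate([0.50, 0.55, 0.55, 0.55, 0.55]):
    scheduler.step(score)  # called once per epoch with the monitored metric
    print(epoch, optim.param_groups[0]["lr"])
# the lr halves after the score has plateaued for `patience` epochs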
     def get_config(self, train_dataloader, eval_dataloader, epochs):
         return {
             "model name": self.model.name_or_path,
             "epochs": epochs,
             "learning rate": self.optimizer.defaults["lr"],
+            "scheduler": "TODO",  # TODO: add scheduler name
             "train batch size": train_dataloader.batch_size,
             "eval batch size": eval_dataloader.batch_size,
             "max len": train_dataloader.dataset.X.shape[-1],
@@ -152,6 +191,7 @@ class Trainer:
             "evaluate every": self.evaluate_steps,
             "print eval every": self.print_eval,
             "print train steps": self.print_steps,
+            "classification type": self.clf_type,
         }

     def train(self, train_dataloader, eval_dataloader, epochs=10):
@@ -168,23 +208,23 @@ class Trainer:
         for epoch in range(epochs):
             train_loss = self.train_epoch(train_dataloader, epoch)

-            wandb_logger.log({f"{self.vgf_name}_train_loss": train_loss})
-
-            if (epoch + 1) % self.evaluate_steps == 0:
-                print_eval = (epoch + 1) % self.print_eval == 0
+            if (epoch + 1) % self.evaluate_steps == 0 or (epoch + 1) == 1:
+                print_eval = verbosity_eval(epoch, self.print_eval)
                 with torch.no_grad():
-                    eval_loss, metric_watcher = self.evaluate(
-                        eval_dataloader, epoch, print_eval=print_eval
+                    eval_loss, avg_metrics, lang_metrics = self.evaluate(
+                        eval_dataloader,
+                        print_eval=print_eval,
+                        n_jobs=self.n_jobs,
                     )
                 wandb_logger.log(
-                    {
-                        f"{self.vgf_name}_eval_loss": eval_loss,
-                        f"{self.vgf_name}_eval_metric": metric_watcher,
-                    }
+                    {"loss/val": eval_loss, **format_langkey_wandb(lang_metrics)},
+                    commit=False,
                 )
-                stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
+                stop = self.earlystopping(
+                    avg_metrics[self.monitored_metric], self.model, epoch + 1
+                )
                 if stop:
                     print(
                         f"- restoring best model from epoch {self.earlystopping.best_epoch} with best metric: {self.earlystopping.best_score:.3f}"
@@ -194,6 +234,16 @@ class Trainer:
                     )
                     break

+            if self.scheduler is not None:
+                self.scheduler.step(avg_metrics[self.monitored_metric])
+
+            wandb_logger.log(
+                {
+                    "loss/train": train_loss,
+                    "learning rate": self.optimizer.param_groups[0]["lr"],
+                }
+            )
+
         print(f"- last sweep on eval set")
         self.train_epoch(eval_dataloader, epoch=-1)
         self.earlystopping.save_model(self.model)
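The commit=False call is the heart of the logging change: wandb buffers everything logged with commit=False and only advances the step on the next default (committing) log call, so validation metrics, train loss, and learning rate all land on the same wandb step. A sketch of that pattern, runnable without an account thanks to offline mode (the project name and metric values are made up):

import wandb

run = wandb.init(project="gfun-demo", mode="offline")  # hypothetical project name

for epoch in range(3):
    # buffered: does not advance the wandb step yet
    wandb.log({"loss/val": 0.4 - 0.1 * epoch, "language metric/accuracy/en": 0.8}, commit=False)
    # commits the buffered values together with the train metrics as one step
    wandb.log({"loss/train": 0.5 - 0.1 * epoch, "learning rate": 1e-5})

run.finish()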
@@ -201,6 +251,7 @@ class Trainer:
     def train_epoch(self, dataloader, epoch):
         self.model.train()
+        epoch_losses = []
         for b_idx, (x, y, lang) in enumerate(dataloader):
             self.optimizer.zero_grad()
             y_hat = self.model(x.to(self.device))
@@ -210,38 +261,47 @@ class Trainer:
             loss = self.loss_fn(y_hat, y.to(self.device))
             loss.backward()
             self.optimizer.step()
+            epoch_losses.append(loss.item())
             if (epoch + 1) % PRINT_ON_EPOCH == 0:
                 if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
-                    print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
-        return loss.item()
+                    print(
+                        f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {np.mean(epoch_losses):.4f}"
+                    )
+        return np.mean(epoch_losses)

-    def evaluate(self, dataloader, epoch, print_eval=True):
+    def evaluate(self, dataloader, print_eval=True, n_jobs=-1):
         self.model.eval()
+        eval_losses = []

-        lY = defaultdict(list)
-        lY_hat = defaultdict(list)
+        lY_true = defaultdict(list)
+        lY_pred = defaultdict(list)

         for b_idx, (x, y, lang) in enumerate(dataloader):
-            y_hat = self.model(x.to(self.device))
-            if isinstance(y_hat, ModelOutput):
-                loss = self.loss_fn(y_hat.logits, y.to(self.device))
-                predictions = predict(y_hat.logits, classification_type="multilabel")
+            y_pred = self.model(x.to(self.device))
+            if isinstance(y_pred, ModelOutput):
+                loss = self.loss_fn(y_pred.logits, y.to(self.device))
+                predictions = predict(y_pred.logits, clf_type=self.clf_type)
             else:
-                loss = self.loss_fn(y_hat, y.to(self.device))
-                predictions = predict(y_hat, classification_type="multilabel")
+                loss = self.loss_fn(y_pred, y.to(self.device))
+                predictions = predict(y_pred, clf_type=self.clf_type)
+            eval_losses.append(loss.item())

             for l, _true, _pred in zip(lang, y, predictions):
-                lY[l].append(_true.detach().cpu().numpy())
-                lY_hat[l].append(_pred)
+                lY_true[l].append(_true.detach().cpu().numpy())
+                lY_pred[l].append(_pred)

-        for lang in lY:
-            lY[lang] = np.vstack(lY[lang])
-            lY_hat[lang] = np.vstack(lY_hat[lang])
+        for lang in lY_true:
+            lY_true[lang] = np.vstack(lY_true[lang])
+            lY_pred[lang] = np.vstack(lY_pred[lang])

-        l_eval = evaluate(lY, lY_hat)
-        average_metrics = log_eval(l_eval, phase="validation", verbose=print_eval)
-        return loss.item(), average_metrics[0]  # macro-F1
+        l_eval = evaluate(lY_true, lY_pred, clf_type=self.clf_type, n_jobs=n_jobs)
+        avg_metrics, lang_metrics = log_eval(
+            l_eval, phase="validation", clf_type=self.clf_type, verbose=print_eval
+        )
+        return np.mean(eval_losses), avg_metrics, lang_metrics


 class EarlyStopping:
@@ -279,7 +339,7 @@ class EarlyStopping:
             print(
                 f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}"
             )
-            if self.counter >= self.patience:
+            if self.counter >= self.patience and self.patience != -1:
                 print(f"- earlystopping: Early stopping at epoch {epoch}")
                 return True
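The extra `self.patience != -1` clause means a patience of -1 now disables early stopping altogether. A tiny sketch of the new condition (the helper name is hypothetical, not from the commit):

def should_stop(counter, patience):
    # patience == -1 is treated as "never stop early"
    return counter >= patience and patience != -1

print(should_stop(5, 5))    # True  -> stop
print(should_stop(99, -1))  # False -> keep training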

gfun/vgfs/textualTransformerGen.py
View File

@@ -7,11 +7,9 @@ from collections import defaultdict
 import numpy as np
 import torch
 import transformers
-
-# from sklearn.model_selection import train_test_split
-# from torch.optim import AdamW
 from torch.utils.data import Dataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
 from gfun.vgfs.commons import Trainer
 from gfun.vgfs.transformerGen import TransformerGen
 from gfun.vgfs.viewGen import ViewGen
@@ -19,9 +17,6 @@ from gfun.vgfs.viewGen import ViewGen
 transformers.logging.set_verbosity_error()

-# TODO: add support to loggers
-
-
 class TextualTransformerGen(ViewGen, TransformerGen):
     def __init__(
         self,
@@ -39,6 +34,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         evaluate_step=10,
         verbose=False,
         patience=5,
+        classification_type="multilabel",
     ):
         super().__init__(
             self._validate_model_name(model_name),
@@ -56,6 +52,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             verbose,
             patience,
         )
+        self.clf_type = classification_type
         self.fitted = False
         print(
             f"- init Textual TransformerModel model_name: {self.model_name}, device: {self.device}]"
@@ -143,6 +140,9 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             experiment_name=experiment_name,
             checkpoint_path="models/vgfs/transformer",
             vgf_name="textual_trf",
+            classification_type=self.clf_type,
+            n_jobs=self.n_jobs,
+            scheduler_name="ReduceLROnPlateau",
         )
         trainer.train(
             train_dataloader=tra_dataloader,

gfun/vgfs/visualTransformerGen.py
View File

@@ -27,6 +27,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
         device="cpu",
         probabilistic=False,
         patience=5,
+        classification_type="multilabel",
     ):
         super().__init__(
             model_name,
@@ -40,6 +41,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
             patience=patience,
             probabilistic=probabilistic,
         )
+        self.clf_type = classification_type
         self.fitted = False
         print(
             f"- init Visual TransformerModel model_name: {self.model_name}, device: {self.device}]"
@@ -113,6 +115,8 @@ class VisualTransformerGen(ViewGen, TransformerGen):
             experiment_name=experiment_name,
             checkpoint_path="models/vgfs/transformer",
             vgf_name="visual_trf",
+            classification_type=self.clf_type,
+            n_jobs=self.n_jobs,
         )
         trainer.train(

main.py
View File

@@ -1,3 +1,7 @@
+import os
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
 from argparse import ArgumentParser
 from time import time
@@ -7,6 +11,7 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
 """
 TODO:
+    - [!] LR scheduler
     - [!] CLS dataset is loading only "books" domain data
     - [!] documents should be trimmed to the same length (?)
     - [!] overall gfun results logger
@@ -42,6 +47,7 @@ def main(args):
         dataset_name=args.dataset,
         langs=dataset.langs(),
         num_labels=dataset.num_labels(),
+        classification_type=args.clf_type,
         # Posterior VGF params ----------------
         posterior=args.posteriors,
         # Multilingual VGF params -------------
@@ -55,7 +61,8 @@ def main(args):
         batch_size=args.batch_size,
         eval_batch_size=args.eval_batch_size,
         epochs=args.epochs,
-        lr=args.lr,
+        textual_lr=args.textual_lr,
+        visual_lr=args.visual_lr,
         max_length=args.max_length,
         patience=args.patience,
         evaluate_step=args.evaluate_step,
@@ -93,8 +100,8 @@ def main(args):
     print(f"- training completed in {timetr - tinit:.2f} seconds")

     gfun_preds = gfun.transform(lX_te)
-    test_eval = evaluate(lY_te, gfun_preds)
-    log_eval(test_eval, phase="test")
+    test_eval = evaluate(lY_te, gfun_preds, clf_type=args.clf_type, n_jobs=args.n_jobs)
+    log_eval(test_eval, phase="test", clf_type=args.clf_type)

     timeval = time()
     print(f"- testing completed in {timeval - timetr:.2f} seconds")
@@ -112,6 +119,7 @@ if __name__ == "__main__":
     parser.add_argument("--nrows", type=int, default=None)
     parser.add_argument("--min_count", type=int, default=10)
     parser.add_argument("--max_labels", type=int, default=50)
+    parser.add_argument("--clf_type", type=str, default="multilabel")
     # gFUN parameters ----------------------
     parser.add_argument("-p", "--posteriors", action="store_true")
     parser.add_argument("-m", "--multilingual", action="store_true")
@@ -127,7 +135,8 @@ if __name__ == "__main__":
     parser.add_argument("--batch_size", type=int, default=32)
     parser.add_argument("--eval_batch_size", type=int, default=128)
     parser.add_argument("--epochs", type=int, default=100)
-    parser.add_argument("--lr", type=float, default=1e-5)
+    parser.add_argument("--textual_lr", type=float, default=1e-5)
+    parser.add_argument("--visual_lr", type=float, default=1e-5)
     parser.add_argument("--max_length", type=int, default=128)
     parser.add_argument("--patience", type=int, default=5)
    parser.add_argument("--evaluate_step", type=int, default=10)