improved wandb logging

This commit is contained in:
parent 3240150542
commit 7e1ec46ebd
@@ -1,46 +1,96 @@
 from joblib import Parallel, delayed
+from collections import defaultdict
 
 from evaluation.metrics import *
+from sklearn.metrics import accuracy_score, top_k_accuracy_score, f1_score
 
 
-def evaluation_metrics(y, y_):
-    if len(y.shape) == len(y_.shape) == 1 and len(np.unique(y)) > 2:  # single-label
-        raise NotImplementedError()
-    else:
+def evaluation_metrics(y, y_, clf_type):
+    if clf_type == "singlelabel":
+        return (
+            accuracy_score(y, y_),
+            # TODO: we need the logits to compute this top_k_accuracy_score(y, y_, k=5),
+            # TODO: we need logits top_k_accuracy_score(y, y_, k=10),
+            f1_score(y, y_, average="macro", zero_division=1),
+            f1_score(y, y_, average="micro"),
+        )
+    elif clf_type == "multilabel":
         return (
             macroF1(y, y_),
             microF1(y, y_),
             macroK(y, y_),
             microK(y, y_),
-            # macroAcc(y, y_),
         )
+    else:
+        raise ValueError("clf_type must be either 'singlelabel' or 'multilabel'")
 
 
-def evaluate(ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1):
+def evaluate(
+    ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1, clf_type="multilabel"
+):
     if n_jobs == 1:
-        return {lang: metrics(ly_true[lang], ly_pred[lang]) for lang in ly_true.keys()}
+        return {
+            lang: metrics(ly_true[lang], ly_pred[lang], clf_type)
+            for lang in ly_true.keys()
+        }
     else:
         langs = list(ly_true.keys())
        evals = Parallel(n_jobs=n_jobs)(
-            delayed(metrics)(ly_true[lang], ly_pred[lang]) for lang in langs
+            delayed(metrics)(ly_true[lang], ly_pred[lang], clf_type) for lang in langs
        )
         return {lang: evals[i] for i, lang in enumerate(langs)}
 
 
-def log_eval(l_eval, phase="training", verbose=True):
+def log_eval(l_eval, phase="training", clf_type="multilabel", verbose=True):
     if verbose:
         print(f"\n[Results {phase}]")
     metrics = []
-    for lang in l_eval.keys():
-        macrof1, microf1, macrok, microk = l_eval[lang]
-        metrics.append([macrof1, microf1, macrok, microk])
-        if phase != "validation":
-            print(f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}")
-    averages = np.mean(np.array(metrics), axis=0)
-    if verbose:
-        print(
-            "Averages: MF1, mF1, MK, mK",
-            np.round(averages, 3),
-            "\n",
-        )
-    return averages
+    if clf_type == "multilabel":
+        for lang in l_eval.keys():
+            macrof1, microf1, macrok, microk = l_eval[lang]
+            metrics.append([macrof1, microf1, macrok, microk])
+            if phase != "validation":
+                print(f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}")
+        averages = np.mean(np.array(metrics), axis=0)
+        if verbose:
+            print(
+                "Averages: MF1, mF1, MK, mK",
+                np.round(averages, 3),
+                "\n",
+            )
+        return averages # TODO: return a dict avg and lang specific
+
+    elif clf_type == "singlelabel":
+        lang_metrics = defaultdict(dict)
+        _metrics = [
+            "accuracy",
+            # "acc5", # "accuracy-at-5",
+            # "acc10", # "accuracy-at-10",
+            "MF1", # "macro-F1",
+            "mF1", # "micro-F1",
+        ]
+        for lang in l_eval.keys():
+            # acc, top5, top10, macrof1, microf1 = l_eval[lang]
+            acc, macrof1, microf1 = l_eval[lang]
+            # metrics.append([acc, top5, top10, macrof1, microf1])
+            metrics.append([acc, macrof1, microf1])
+
+            for m, v in zip(_metrics, l_eval[lang]):
+                lang_metrics[m][lang] = v
+
+            if phase != "validation":
+                print(
+                    # f"Lang {lang}: acc = {acc:.3f} acc-top5 = {top5:.3f} acc-top10 = {top10:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
+                    f"Lang {lang}: acc = {acc:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
+                )
+        averages = np.mean(np.array(metrics), axis=0)
+        if verbose:
+            print(
+                # "Averages: Acc, Acc-5, Acc-10, MF1, mF1",
+                "Averages: Acc, MF1, mF1",
+                np.round(averages, 3),
+                "\n",
+            )
+        avg_metrics = dict(zip(_metrics, averages))
+        return avg_metrics, lang_metrics
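For context, a minimal usage sketch of the reworked evaluation entry points above (not part of the commit): the label matrices are invented toy values, and the calls assume they run where evaluate, log_eval and the metrics from evaluation.metrics are importable.

import numpy as np

# two languages, two labels, binary indicator matrices (toy data)
ly_true = {"en": np.array([[1, 0], [0, 1]]), "it": np.array([[0, 1], [1, 1]])}
ly_pred = {"en": np.array([[1, 0], [0, 1]]), "it": np.array([[0, 1], [1, 0]])}

# clf_type now selects the metric set: "multilabel" keeps the original
# (macro-F1, micro-F1, macro-K, micro-K) tuple, while "singlelabel" switches
# to sklearn accuracy / macro-F1 / micro-F1 and makes log_eval return
# (avg_metrics, lang_metrics) instead of a single averages array.
l_eval = evaluate(ly_true, ly_pred, clf_type="multilabel", n_jobs=1)
averages = log_eval(l_eval, phase="test", clf_type="multilabel")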
@@ -1,17 +1,14 @@
 import os
-import sys
 
-# sys.path.append(os.path.join(os.getcwd(), "gfun"))
-
 import pickle
 
 import numpy as np
-from gfun.vgfs.commons import TfidfVectorizerMultilingual, AttentionAggregator
+from gfun.vgfs.commons import AttentionAggregator, TfidfVectorizerMultilingual, predict
 from gfun.vgfs.learners.svms import MetaClassifier, get_learner
 from gfun.vgfs.multilingualGen import MultilingualGen
 from gfun.vgfs.textualTransformerGen import TextualTransformerGen
-from gfun.vgfs.visualTransformerGen import VisualTransformerGen
 from gfun.vgfs.vanillaFun import VanillaFunGen
+from gfun.vgfs.visualTransformerGen import VisualTransformerGen
 from gfun.vgfs.wceGen import WceGen
 
@@ -25,12 +22,14 @@ class GeneralizedFunnelling:
         visual_transformer,
         langs,
         num_labels,
+        classification_type,
         embed_dir,
         n_jobs,
         batch_size,
         eval_batch_size,
         max_length,
-        lr,
+        textual_lr,
+        visual_lr,
         epochs,
         patience,
         evaluate_step,
@@ -52,6 +51,7 @@ class GeneralizedFunnelling:
         self.visual_trf_vgf = visual_transformer
         self.probabilistic = probabilistic
         self.num_labels = num_labels
+        self.clf_type = classification_type
         # ------------------------
         self.langs = langs
         self.embed_dir = embed_dir
@@ -59,7 +59,8 @@ class GeneralizedFunnelling:
         # Textual Transformer VGF params ----------
         self.textual_trf_name = textual_transformer_name
         self.epochs = epochs
-        self.lr_transformer = lr
+        self.txt_trf_lr = textual_lr
+        self.vis_trf_lr = visual_lr
         self.batch_size_trf = batch_size
         self.eval_batch_size_trf = eval_batch_size
         self.max_length = max_length
@@ -114,7 +115,7 @@ class GeneralizedFunnelling:
         self.attn_aggregator = AttentionAggregator(
             embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
             out_dim=self.num_labels,
-            lr=self.lr_transformer,
+            lr=self.txt_trf_lr,
             patience=self.patience,
             num_heads=1,
             device=self.device,
@@ -148,7 +149,7 @@ class GeneralizedFunnelling:
         transformer_vgf = TextualTransformerGen(
             dataset_name=self.dataset_name,
             model_name=self.textual_trf_name,
-            lr=self.lr_transformer,
+            lr=self.txt_trf_lr,
             epochs=self.epochs,
             batch_size=self.batch_size_trf,
             batch_size_eval=self.eval_batch_size_trf,
@@ -159,6 +160,7 @@ class GeneralizedFunnelling:
             verbose=True,
             patience=self.patience,
             device=self.device,
+            classification_type=self.clf_type,
         )
         self.first_tier_learners.append(transformer_vgf)
 
@@ -166,7 +168,7 @@ class GeneralizedFunnelling:
         visual_trasformer_vgf = VisualTransformerGen(
             dataset_name=self.dataset_name,
             model_name="vit",
-            lr=self.lr_transformer,
+            lr=self.vis_trf_lr,
             epochs=self.epochs,
             batch_size=self.batch_size_trf,
             batch_size_eval=self.eval_batch_size_trf,
@@ -174,6 +176,7 @@
             evaluate_step=self.evaluate_step,
             patience=self.patience,
             device=self.device,
+            classification_type=self.clf_type,
         )
         self.first_tier_learners.append(visual_trasformer_vgf)
 
@@ -182,7 +185,7 @@
         self.attn_aggregator = AttentionAggregator(
             embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
             out_dim=self.num_labels,
-            lr=self.lr_transformer,
+            lr=self.txt_trf_lr,
             patience=self.patience,
             num_heads=1,
             device=self.device,
@@ -255,10 +258,9 @@
             projections.append(l_posteriors)
         agg = self.aggregate(projections)
         l_out = self.metaclassifier.predict_proba(agg)
-        # converting to binary predictions
-        # if self.dataset_name in ["cls"]: # TODO: better way to do this
-        # for lang, preds in l_out.items():
-        # l_out[lang] = np.expand_dims(np.argmax(preds, axis=1), axis=1)
+        if self.clf_type == "singlelabel":
+            for lang, preds in l_out.items():
+                l_out[lang] = predict(preds, clf_type=self.clf_type)
         return l_out
 
     def fit_transform(self, lX, lY):
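As an aside, a small sketch (not part of the commit) of what the new singlelabel branch in transform() does to the metaclassifier posteriors through the predict() helper imported from gfun.vgfs.commons (shown in the next diff); the probability values are invented.

import numpy as np
import torch

posteriors = np.array([[0.2, 0.7, 0.1], [0.6, 0.3, 0.1]])  # toy metaclassifier output
logits = torch.tensor(posteriors)             # predict() tensorizes non-tensor input
probs = torch.softmax(logits, dim=1).numpy()  # softmax preserves the argmax
onehot = np.eye(probs.shape[1])[probs.argmax(axis=1)]
# onehot -> [[0., 1., 0.],
#            [1., 0., 0.]]  hard one-hot predictions, one row per document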
@@ -9,6 +9,7 @@ from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import normalize
 from torch.optim import AdamW
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.utils.data import DataLoader, Dataset
 from transformers.modeling_outputs import ModelOutput
 
@@ -22,6 +23,21 @@ def _normalize(lX, l2=True):
     return {lang: normalize(np.asarray(X)) for lang, X in lX.items()} if l2 else lX
 
 
+def verbosity_eval(epoch, print_eval):
+    if (epoch + 1) % print_eval == 0 and epoch != 0:
+        return True
+    else:
+        return False
+
+
+def format_langkey_wandb(lang_dict):
+    log_dict = {}
+    for metric, l_dict in lang_dict.items():
+        for lang, value in l_dict.items():
+            log_dict[f"language metric/{metric}/{lang}"] = value
+    return log_dict
+
+
 def XdotM(X, M, sif):
     E = X.dot(M)
     if sif:
@@ -58,18 +74,23 @@ def compute_pc(X, npc=1):
     return svd.components_
 
 
-def predict(logits, classification_type="multilabel"):
+def predict(logits, clf_type="multilabel"):
     """
     Converts soft precictions to hard predictions [0,1]
     """
-    if classification_type == "multilabel":
+    if clf_type == "multilabel":
         prediction = torch.sigmoid(logits) > 0.5
-    elif classification_type == "singlelabel":
-        prediction = torch.argmax(logits, dim=1).view(-1, 1)
+        return prediction.detach().cpu().numpy()
+    elif clf_type == "singlelabel":
+        if type(logits) != torch.Tensor:
+            logits = torch.tensor(logits)
+        prediction = torch.softmax(logits, dim=1)
+        prediction = prediction.detach().cpu().numpy()
+        _argmaxs = prediction.argmax(axis=1)
+        prediction = np.eye(prediction.shape[1])[_argmaxs]
+        return prediction
     else:
-        print("unknown classification type")
-
-    return prediction.detach().cpu().numpy()
+        raise NotImplementedError()
 
 
 class TfidfVectorizerMultilingual:
@@ -115,36 +136,54 @@ class Trainer:
         patience,
         experiment_name,
         checkpoint_path,
+        classification_type,
         vgf_name,
+        n_jobs,
+        scheduler_name=None,
     ):
         self.device = device
         self.model = model.to(device)
-        self.optimizer = self.init_optimizer(optimizer_name, lr)
+        self.optimizer, self.scheduler = self.init_optimizer(
+            optimizer_name, lr, scheduler_name
+        )
         self.evaluate_steps = evaluate_step
         self.loss_fn = loss_fn.to(device)
         self.print_steps = print_steps
         self.experiment_name = experiment_name
         self.patience = patience
-        self.print_eval = evaluate_step
+        self.print_eval = 10
         self.earlystopping = EarlyStopping(
             patience=patience,
             checkpoint_path=checkpoint_path,
             verbose=False,
             experiment_name=experiment_name,
         )
+        self.clf_type = classification_type
         self.vgf_name = vgf_name
+        self.n_jobs = n_jobs
+        self.monitored_metric = (
+            "macro-F1" if self.clf_type == "multilabel" else "accuracy"
+        ) # TODO: make this configurable
 
-    def init_optimizer(self, optimizer_name, lr):
+    def init_optimizer(self, optimizer_name, lr, scheduler_name):
         if optimizer_name.lower() == "adamw":
-            return AdamW(self.model.parameters(), lr=lr)
+            optim = AdamW(self.model.parameters(), lr=lr)
         else:
             raise ValueError(f"Optimizer {optimizer_name} not supported")
+        if scheduler_name is None:
+            scheduler = None
+        elif scheduler_name == "ReduceLROnPlateau":
+            scheduler = ReduceLROnPlateau(optim, "max", factor=0.5, min_lr=1e-5)
+        else:
+            raise ValueError(f"Scheduler {scheduler_name} not supported")
+        return optim, scheduler
 
     def get_config(self, train_dataloader, eval_dataloader, epochs):
         return {
             "model name": self.model.name_or_path,
             "epochs": epochs,
             "learning rate": self.optimizer.defaults["lr"],
+            "scheduler": "TODO", # TODO: add scheduler name
             "train batch size": train_dataloader.batch_size,
             "eval batch size": eval_dataloader.batch_size,
             "max len": train_dataloader.dataset.X.shape[-1],
@@ -152,6 +191,7 @@ class Trainer:
             "evaluate every": self.evaluate_steps,
             "print eval every": self.print_eval,
             "print train steps": self.print_steps,
+            "classification type": self.clf_type,
         }
 
     def train(self, train_dataloader, eval_dataloader, epochs=10):
@@ -168,23 +208,23 @@ class Trainer:
         for epoch in range(epochs):
             train_loss = self.train_epoch(train_dataloader, epoch)
 
-            wandb_logger.log({f"{self.vgf_name}_train_loss": train_loss})
-
-            if (epoch + 1) % self.evaluate_steps == 0:
-                print_eval = (epoch + 1) % self.print_eval == 0
+            if (epoch + 1) % self.evaluate_steps == 0 or (epoch + 1) == 1:
+                print_eval = verbosity_eval(epoch, self.print_eval)
                 with torch.no_grad():
-                    eval_loss, metric_watcher = self.evaluate(
-                        eval_dataloader, epoch, print_eval=print_eval
+                    eval_loss, avg_metrics, lang_metrics = self.evaluate(
+                        eval_dataloader,
+                        print_eval=print_eval,
+                        n_jobs=self.n_jobs,
                     )
 
                 wandb_logger.log(
-                    {
-                        f"{self.vgf_name}_eval_loss": eval_loss,
-                        f"{self.vgf_name}_eval_metric": metric_watcher,
-                    }
+                    {"loss/val": eval_loss, **format_langkey_wandb(lang_metrics)},
+                    commit=False,
                 )
 
-                stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
+                stop = self.earlystopping(
+                    avg_metrics[self.monitored_metric], self.model, epoch + 1
+                )
                 if stop:
                     print(
                         f"- restoring best model from epoch {self.earlystopping.best_epoch} with best metric: {self.earlystopping.best_score:3f}"
@@ -194,6 +234,16 @@
                     )
                     break
 
+            if self.scheduler is not None:
+                self.scheduler.step(avg_metrics[self.monitored_metric])
+
+            wandb_logger.log(
+                {
+                    "loss/train": train_loss,
+                    "learning rate": self.optimizer.param_groups[0]["lr"],
+                }
+            )
+
         print(f"- last swipe on eval set")
         self.train_epoch(eval_dataloader, epoch=-1)
         self.earlystopping.save_model(self.model)
@@ -201,6 +251,7 @@
 
     def train_epoch(self, dataloader, epoch):
         self.model.train()
+        epoch_losses = []
         for b_idx, (x, y, lang) in enumerate(dataloader):
             self.optimizer.zero_grad()
             y_hat = self.model(x.to(self.device))
@@ -210,38 +261,47 @@
             loss = self.loss_fn(y_hat, y.to(self.device))
             loss.backward()
             self.optimizer.step()
+            epoch_losses.append(loss.item())
             if (epoch + 1) % PRINT_ON_EPOCH == 0:
                 if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
-                    print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
-        return loss.item()
+                    print(
+                        f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {np.mean(epoch_losses):.4f}"
+                    )
+        return np.mean(epoch_losses)
 
-    def evaluate(self, dataloader, epoch, print_eval=True):
+    def evaluate(self, dataloader, print_eval=True, n_jobs=-1):
         self.model.eval()
+        eval_losses = []
 
-        lY = defaultdict(list)
-        lY_hat = defaultdict(list)
+        lY_true = defaultdict(list)
+        lY_pred = defaultdict(list)
 
         for b_idx, (x, y, lang) in enumerate(dataloader):
-            y_hat = self.model(x.to(self.device))
-            if isinstance(y_hat, ModelOutput):
-                loss = self.loss_fn(y_hat.logits, y.to(self.device))
-                predictions = predict(y_hat.logits, classification_type="multilabel")
+            y_pred = self.model(x.to(self.device))
+            if isinstance(y_pred, ModelOutput):
+                loss = self.loss_fn(y_pred.logits, y.to(self.device))
+                predictions = predict(y_pred.logits, clf_type=self.clf_type)
             else:
-                loss = self.loss_fn(y_hat, y.to(self.device))
-                predictions = predict(y_hat, classification_type="multilabel")
+                loss = self.loss_fn(y_pred, y.to(self.device))
+                predictions = predict(y_pred, clf_type=self.clf_type)
+
+            eval_losses.append(loss.item())
 
             for l, _true, _pred in zip(lang, y, predictions):
-                lY[l].append(_true.detach().cpu().numpy())
-                lY_hat[l].append(_pred)
+                lY_true[l].append(_true.detach().cpu().numpy())
+                lY_pred[l].append(_pred)
 
-        for lang in lY:
-            lY[lang] = np.vstack(lY[lang])
-            lY_hat[lang] = np.vstack(lY_hat[lang])
+        for lang in lY_true:
+            lY_true[lang] = np.vstack(lY_true[lang])
+            lY_pred[lang] = np.vstack(lY_pred[lang])
 
-        l_eval = evaluate(lY, lY_hat)
-        average_metrics = log_eval(l_eval, phase="validation", verbose=print_eval)
-
-        return loss.item(), average_metrics[0] # macro-F1
+        l_eval = evaluate(lY_true, lY_pred, clf_type=self.clf_type, n_jobs=n_jobs)
+
+        avg_metrics, lang_metrics = log_eval(
+            l_eval, phase="validation", clf_type=self.clf_type, verbose=print_eval
+        )
+
+        return np.mean(eval_losses), avg_metrics, lang_metrics
 
 
 class EarlyStopping:
@@ -279,7 +339,7 @@ class EarlyStopping:
            print(
                f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}"
            )
-            if self.counter >= self.patience:
+            if self.counter >= self.patience and self.patience != -1:
                print(f"- earlystopping: Early stopping at epoch {epoch}")
                return True
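For reference, a quick illustration (not part of the commit) of how format_langkey_wandb() flattens the per-language metrics returned by log_eval() into the wandb keys logged above; the metric values are invented.

lang_metrics = {
    "accuracy": {"en": 0.91, "it": 0.88},
    "MF1": {"en": 0.74, "it": 0.69},
}
format_langkey_wandb(lang_metrics)
# -> {"language metric/accuracy/en": 0.91, "language metric/accuracy/it": 0.88,
#     "language metric/MF1/en": 0.74, "language metric/MF1/it": 0.69}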
@@ -7,11 +7,9 @@ from collections import defaultdict
 import numpy as np
 import torch
 import transformers
 
-# from sklearn.model_selection import train_test_split
-# from torch.optim import AdamW
 from torch.utils.data import Dataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
 from gfun.vgfs.commons import Trainer
 from gfun.vgfs.transformerGen import TransformerGen
 from gfun.vgfs.viewGen import ViewGen
@@ -19,9 +17,6 @@ from gfun.vgfs.viewGen import ViewGen
 transformers.logging.set_verbosity_error()
 
 
-# TODO: add support to loggers
-
-
 class TextualTransformerGen(ViewGen, TransformerGen):
     def __init__(
         self,
@@ -39,6 +34,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         evaluate_step=10,
         verbose=False,
         patience=5,
+        classification_type="multilabel",
     ):
         super().__init__(
             self._validate_model_name(model_name),
@@ -56,6 +52,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             verbose,
             patience,
         )
+        self.clf_type = classification_type
         self.fitted = False
         print(
             f"- init Textual TransformerModel model_name: {self.model_name}, device: {self.device}]"
@@ -143,6 +140,9 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             experiment_name=experiment_name,
             checkpoint_path="models/vgfs/transformer",
             vgf_name="textual_trf",
+            classification_type=self.clf_type,
+            n_jobs=self.n_jobs,
+            scheduler_name="ReduceLROnPlateau",
         )
         trainer.train(
             train_dataloader=tra_dataloader,
@@ -27,6 +27,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
         device="cpu",
         probabilistic=False,
         patience=5,
+        classification_type="multilabel",
     ):
         super().__init__(
             model_name,
@@ -40,6 +41,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
             patience=patience,
             probabilistic=probabilistic,
         )
+        self.clf_type = classification_type
         self.fitted = False
         print(
             f"- init Visual TransformerModel model_name: {self.model_name}, device: {self.device}]"
@@ -113,6 +115,8 @@ class VisualTransformerGen(ViewGen, TransformerGen):
             experiment_name=experiment_name,
             checkpoint_path="models/vgfs/transformer",
             vgf_name="visual_trf",
+            classification_type=self.clf_type,
+            n_jobs=self.n_jobs,
         )
 
         trainer.train(
main.py (17 changed lines)
@@ -1,3 +1,7 @@
+import os
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
 from argparse import ArgumentParser
 from time import time
 
@@ -7,6 +11,7 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
 
 """
 TODO:
+    - [!] LR scheduler
     - [!] CLS dataset is loading only "books" domain data
     - [!] documents should be trimmed to the same length (?)
     - [!] overall gfun results logger
@@ -42,6 +47,7 @@ def main(args):
         dataset_name=args.dataset,
         langs=dataset.langs(),
         num_labels=dataset.num_labels(),
+        classification_type=args.clf_type,
         # Posterior VGF params ----------------
         posterior=args.posteriors,
         # Multilingual VGF params -------------
@@ -55,7 +61,8 @@ def main(args):
         batch_size=args.batch_size,
         eval_batch_size=args.eval_batch_size,
         epochs=args.epochs,
-        lr=args.lr,
+        textual_lr=args.textual_lr,
+        visual_lr=args.visual_lr,
         max_length=args.max_length,
         patience=args.patience,
         evaluate_step=args.evaluate_step,
@@ -93,8 +100,8 @@ def main(args):
     print(f"- training completed in {timetr - tinit:.2f} seconds")
 
     gfun_preds = gfun.transform(lX_te)
-    test_eval = evaluate(lY_te, gfun_preds)
-    log_eval(test_eval, phase="test")
+    test_eval = evaluate(lY_te, gfun_preds, clf_type=args.clf_type, n_jobs=args.n_jobs)
+    log_eval(test_eval, phase="test", clf_type=args.clf_type)
 
     timeval = time()
     print(f"- testing completed in {timeval - timetr:.2f} seconds")
@@ -112,6 +119,7 @@ if __name__ == "__main__":
     parser.add_argument("--nrows", type=int, default=None)
     parser.add_argument("--min_count", type=int, default=10)
     parser.add_argument("--max_labels", type=int, default=50)
+    parser.add_argument("--clf_type", type=str, default="multilabel")
     # gFUN parameters ----------------------
     parser.add_argument("-p", "--posteriors", action="store_true")
     parser.add_argument("-m", "--multilingual", action="store_true")
@@ -127,7 +135,8 @@
     parser.add_argument("--batch_size", type=int, default=32)
     parser.add_argument("--eval_batch_size", type=int, default=128)
     parser.add_argument("--epochs", type=int, default=100)
-    parser.add_argument("--lr", type=float, default=1e-5)
+    parser.add_argument("--textual_lr", type=float, default=1e-5)
+    parser.add_argument("--visual_lr", type=float, default=1e-5)
     parser.add_argument("--max_length", type=int, default=128)
     parser.add_argument("--patience", type=int, default=5)
     parser.add_argument("--evaluate_step", type=int, default=10)