import os
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from transformers.modeling_outputs import ModelOutput

import wandb
from evaluation.evaluate import evaluate, log_eval

PRINT_ON_EPOCH = 1


def _normalize(lX, l2=True):
    return {lang: normalize(np.asarray(X)) for lang, X in lX.items()} if l2 else lX


def verbosity_eval(epoch, print_eval):
    return (epoch + 1) % print_eval == 0 and epoch != 0


def format_langkey_wandb(lang_dict, vgf_name):
    log_dict = {}
    for metric, l_dict in lang_dict.items():
        for lang, value in l_dict.items():
            log_dict[f"{vgf_name}/language metric/{metric}/{lang}"] = value
    return log_dict


def format_average_wandb(avg_dict, vgf_name):
    log_dict = {}
    for metric, value in avg_dict.items():
        log_dict[f"{vgf_name}/average metric/{metric}"] = value
    return log_dict


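# Illustrative sketch (never called by the pipeline): the wandb key layout produced by the
# two helpers above, for a hypothetical VGF named "textual" reporting macro-F1.
def _example_wandb_key_format():
    lang_metrics = {"macro-F1": {"en": 0.81, "it": 0.78}}
    avg_metrics = {"macro-F1": 0.795}
    assert format_langkey_wandb(lang_metrics, "textual") == {
        "textual/language metric/macro-F1/en": 0.81,
        "textual/language metric/macro-F1/it": 0.78,
    }
    assert format_average_wandb(avg_metrics, "textual") == {
        "textual/average metric/macro-F1": 0.795
    }

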
def XdotM(X, M, sif):
    E = X.dot(M)
    if sif:
        E = remove_pc(E, npc=1)
    return E


def remove_pc(X, npc=1):
    """
    Remove the projection on the principal components
    :param X: X[i,:] is a data point
    :param npc: number of principal components to remove
    :return: XX[i, :] is the data point after removing its projection
    """
    pc = compute_pc(X, npc)
    if npc == 1:
        XX = X - X.dot(pc.transpose()) * pc
    else:
        XX = X - X.dot(pc.transpose()).dot(pc)
    return XX


def compute_pc(X, npc=1):
    """
    Compute the principal components.
    :param X: X[i,:] is a data point
    :param npc: number of principal components to remove
    :return: component_[i,:] is the i-th pc
    """
    if isinstance(X, np.matrix):
        X = np.asarray(X)
    svd = TruncatedSVD(n_components=npc, n_iter=7, random_state=0)
    svd.fit(X)
    return svd.components_


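# Illustrative sketch (never called by the pipeline): SIF-style post-processing with the
# helpers above. Removing the first principal component preserves the matrix shape.
# The matrices and their dimensions are hypothetical.
def _example_sif_removal():
    rng = np.random.default_rng(0)
    X = rng.standard_normal((100, 50))  # 100 documents, 50-dim representations
    M = rng.standard_normal((50, 50))   # a projection matrix (e.g. word-class embeddings)
    E = XdotM(X, M, sif=True)           # project, then strip the 1st principal component
    assert E.shape == (100, 50)

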
def predict(logits, clf_type="multilabel"):
    """
    Converts soft predictions to hard predictions in {0, 1}
    """
    if clf_type == "multilabel":
        prediction = torch.sigmoid(logits) > 0.5
        return prediction.detach().cpu().numpy()
    elif clf_type == "singlelabel":
        if not isinstance(logits, torch.Tensor):
            logits = torch.tensor(logits)
        prediction = torch.softmax(logits, dim=1)
        prediction = prediction.detach().cpu().numpy()
        _argmaxs = prediction.argmax(axis=1)
        # one-hot encode the argmax of the softmax distribution
        prediction = np.eye(prediction.shape[1])[_argmaxs]
        return prediction
    else:
        raise NotImplementedError()


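# Illustrative sketch (never called by the pipeline): hard predictions in both settings.
# Multilabel thresholds each sigmoid at 0.5; singlelabel one-hot encodes the softmax argmax.
def _example_predict():
    logits = torch.tensor([[2.0, -1.0, 0.5], [-0.3, 1.2, -2.0]])
    multi = predict(logits, clf_type="multilabel")
    single = predict(logits, clf_type="singlelabel")
    assert (multi == np.array([[1, 0, 1], [0, 1, 0]])).all()
    assert (single == np.array([[1, 0, 0], [0, 1, 0]])).all()

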
class TfidfVectorizerMultilingual:
    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def fit(self, lX, ly=None):
        self.langs = sorted(lX.keys())
        self.vectorizer = {
            l: TfidfVectorizer(**self.kwargs).fit(lX[l]["text"]) for l in self.langs
        }
        return self

    def transform(self, lX):
        return {l: self.vectorizer[l].transform(lX[l]["text"]) for l in self.langs}

    def fit_transform(self, lX, ly=None):
        return self.fit(lX, ly).transform(lX)

    def vocabulary(self, l=None):
        if l is None:
            return {l: self.vectorizer[l].vocabulary_ for l in self.langs}
        else:
            return self.vectorizer[l].vocabulary_

    def get_analyzer(self, l=None):
        if l is None:
            return {l: self.vectorizer[l].build_analyzer() for l in self.langs}
        else:
            return self.vectorizer[l].build_analyzer()


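# Illustrative sketch (never called by the pipeline): fitting one TF-IDF vectorizer per
# language on a {lang: {"text": [...]}} dictionary, the input layout assumed by fit/transform
# above. Documents and languages are hypothetical.
def _example_multilingual_tfidf():
    lX = {
        "en": {"text": ["the cat sat on the mat", "dogs bark loudly"]},
        "it": {"text": ["il gatto dorme", "i cani abbaiano"]},
    }
    vec = TfidfVectorizerMultilingual(sublinear_tf=True)
    lZ = vec.fit_transform(lX)
    assert set(lZ.keys()) == {"en", "it"}
    assert lZ["en"].shape[0] == 2  # one row per English document

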
class Trainer:
    def __init__(
        self,
        model,
        optimizer_name,
        device,
        loss_fn,
        lr,
        print_steps,
        evaluate_step,
        patience,
        experiment_name,
        checkpoint_path,
        classification_type,
        vgf_name,
        n_jobs,
        scheduler_name=None,
    ):
        self.device = device
        self.model = model.to(device)
        self.optimizer, self.scheduler = self.init_optimizer(
            optimizer_name, lr, scheduler_name
        )
        self.evaluate_steps = evaluate_step
        self.loss_fn = loss_fn.to(device)
        self.print_steps = print_steps
        self.experiment_name = experiment_name
        self.patience = patience
        self.print_eval = 10
        self.earlystopping = EarlyStopping(
            patience=patience,
            checkpoint_path=checkpoint_path,
            verbose=False,
            experiment_name=experiment_name,
        )
        self.clf_type = classification_type
        self.vgf_name = vgf_name
        self.scheduler_name = scheduler_name
        self.n_jobs = n_jobs
        self.monitored_metric = (
            "macro-F1" if self.clf_type == "multilabel" else "accuracy"
        )  # TODO: make this configurable

    def init_optimizer(self, optimizer_name, lr, scheduler_name):
        if optimizer_name.lower() == "adamw":
            optim = AdamW(self.model.parameters(), lr=lr)
        else:
            raise ValueError(f"Optimizer {optimizer_name} not supported")
        if scheduler_name is None:
            scheduler = None
        elif scheduler_name == "ReduceLROnPlateau":
            scheduler = ReduceLROnPlateau(optim, "max", factor=0.5, min_lr=1e-5)
        else:
            raise ValueError(f"Scheduler {scheduler_name} not supported")
        return optim, scheduler

    def get_config(self, train_dataloader, eval_dataloader, epochs):
        return {
            # fall back to the class name for models (e.g. AttentionModule) that do not
            # expose a HuggingFace-style name_or_path attribute
            "model name": self.model.mt5encoder.name_or_path
            if hasattr(self.model, "mt5encoder")
            else getattr(self.model, "name_or_path", self.model.__class__.__name__),
            "epochs": epochs,
            "learning rate": self.optimizer.defaults["lr"],
            "scheduler": self.scheduler_name,  # TODO: add scheduler params
            "train size": len(train_dataloader.dataset),
            "eval size": len(eval_dataloader.dataset),
            "train batch size": train_dataloader.batch_size,
            "eval batch size": eval_dataloader.batch_size,
            "max len": train_dataloader.dataset.X.shape[-1],
            "patience": self.earlystopping.patience,
            "evaluate every": self.evaluate_steps,
            "print eval every": self.print_eval,
            "print train steps": self.print_steps,
            "classification type": self.clf_type,
        }

    def train(self, train_dataloader, eval_dataloader, epochs=10):
        _config = self.get_config(train_dataloader, eval_dataloader, epochs)

        print(f"- Training params for {self.experiment_name}:")
        for k, v in _config.items():
            print(f"\t{k}: {v}")

        for epoch in range(epochs):
            train_loss = self.train_epoch(train_dataloader, epoch)

            if (epoch + 1) % self.evaluate_steps == 0 or (epoch + 1) == 1:
                print_eval = verbosity_eval(epoch, self.print_eval)
                with torch.no_grad():
                    eval_loss, avg_metrics, lang_metrics = self.evaluate(
                        eval_dataloader,
                        print_eval=print_eval,
                        n_jobs=self.n_jobs,
                    )

                wandb.log(
                    {
                        f"{self.vgf_name}/loss/val": eval_loss,
                        **format_langkey_wandb(lang_metrics, self.vgf_name),
                        **format_average_wandb(avg_metrics, self.vgf_name),
                    },
                    commit=False,
                )

                stop = self.earlystopping(
                    avg_metrics[self.monitored_metric], self.model, epoch + 1
                )
                if stop:
                    print(
                        f"- restoring best model from epoch {self.earlystopping.best_epoch} with best metric: {self.earlystopping.best_score:.3f}"
                    )
                    restored_model = self.earlystopping.load_model(self.model)

                    # swap the restored model back onto the training device
                    del self.model
                    self.model = restored_model.to(self.device)
                    break

                if self.scheduler is not None:
                    self.scheduler.step(avg_metrics[self.monitored_metric])

            wandb.log(
                {
                    f"{self.vgf_name}/loss/train": train_loss,
                    f"{self.vgf_name}/learning rate": self.optimizer.param_groups[0]["lr"],
                }
            )

        # final training pass over the evaluation set before saving the model
        print("- last sweep on the eval set")
        self.train_epoch(
            DataLoader(
                eval_dataloader.dataset,
                batch_size=train_dataloader.batch_size,
                shuffle=True,
            ),
            epoch=-1,
        )
        self.earlystopping.save_model(self.model)
        return self.model

    def train_epoch(self, dataloader, epoch):
        self.model.train()
        batch_losses = []
        for b_idx, (x, y, lang) in enumerate(dataloader):
            self.optimizer.zero_grad()
            y_hat = self.model(x.to(self.device))
            if isinstance(y_hat, ModelOutput):
                loss = self.loss_fn(y_hat.logits, y.to(self.device))
            else:
                loss = self.loss_fn(y_hat, y.to(self.device))
            loss.backward()
            self.optimizer.step()
            batch_losses.append(loss.item())
            if (epoch + 1) % PRINT_ON_EPOCH == 0:
                if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
                    print(
                        f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {np.mean(batch_losses):.4f}"
                    )
        return np.mean(batch_losses)

    def evaluate(self, dataloader, print_eval=True, n_jobs=-1):
        self.model.eval()
        eval_losses = []

        lY_true = defaultdict(list)
        lY_pred = defaultdict(list)

        for b_idx, (x, y, lang) in enumerate(dataloader):
            y_pred = self.model(x.to(self.device))
            if isinstance(y_pred, ModelOutput):
                loss = self.loss_fn(y_pred.logits, y.to(self.device))
                predictions = predict(y_pred.logits, clf_type=self.clf_type)
            else:
                loss = self.loss_fn(y_pred, y.to(self.device))
                predictions = predict(y_pred, clf_type=self.clf_type)

            eval_losses.append(loss.item())

            for l, _true, _pred in zip(lang, y, predictions):
                lY_true[l].append(_true.detach().cpu().numpy())
                lY_pred[l].append(_pred)

        for lang in lY_true:
            lY_true[lang] = np.vstack(lY_true[lang])
            lY_pred[lang] = np.vstack(lY_pred[lang])

        l_eval = evaluate(lY_true, lY_pred, clf_type=self.clf_type, n_jobs=n_jobs)

        avg_metrics, lang_metrics = log_eval(
            l_eval, phase="validation", clf_type=self.clf_type, verbose=print_eval
        )

        return np.mean(eval_losses), avg_metrics, lang_metrics


class EarlyStopping:
    def __init__(
        self,
        patience,
        checkpoint_path,
        experiment_name,
        min_delta=0,
        verbose=True,
    ):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = 0
        self.best_epoch = None
        self.verbose = verbose
        self.checkpoint_path = checkpoint_path
        self.experiment_name = experiment_name

    def __call__(self, validation, model, epoch):
        if validation >= self.best_score:
            wandb.log({"patience": self.patience - self.counter})
            if self.verbose:
                print(
                    f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
                )
            self.best_score = validation
            self.counter = 0
            self.best_epoch = epoch
            # print(f"- earlystopping: Saving best model from epoch {epoch}")
            self.save_model(model)
        elif validation < (self.best_score + self.min_delta):
            self.counter += 1
            wandb.log({"patience": self.patience - self.counter})
            if self.verbose:
                print(
                    f"- earlystopping: Validation score did not improve from {self.best_score:.3f} (got {validation:.3f}), current patience: {self.patience - self.counter}"
                )
            if self.counter >= self.patience and self.patience != -1:
                print(f"- earlystopping: Early stopping at epoch {epoch}")
                return True
        return False

    def save_model(self, model):
        os.makedirs(self.checkpoint_path, exist_ok=True)
        _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
        model.save_pretrained(_checkpoint_dir)

    def load_model(self, model):
        _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
        return model.from_pretrained(_checkpoint_dir)


class AttentionModule(nn.Module):
    def __init__(self, embed_dim, num_heads, h_dim, out_dim, aggfunc_type):
        """We call sigmoid in the evaluation loop (Trainer.evaluate), so we do not apply
        it explicitly here at training time. However, we should explicitly squash the
        outputs through the sigmoid at inference time (self.transform) (???)
        """
        super().__init__()
        self.aggfunc = aggfunc_type
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.1)
        # self.layer_norm = nn.LayerNorm(embed_dim)
        if self.aggfunc == "concat":
            self.linear = nn.Linear(embed_dim, out_dim)
        self.sigmoid = nn.Sigmoid()

    def init_weights(self, mode="mean"):
        # TODO: add init function of the attention module: either all weights are positive or set to 1/num_classes
        raise NotImplementedError

    def forward(self, X):
        # NOTE: defined as forward (rather than overriding __call__) so that nn.Module
        # dispatch and hooks work as expected; callers still invoke the module directly.
        out, attn_weights = self.attn(query=X, key=X, value=X)
        # out = self.layer_norm(out)
        if self.aggfunc == "concat":
            out = self.linear(out)
        # out = self.sigmoid(out)
        return out

    def transform(self, X):
        """explicitly calling sigmoid at inference time"""
        out, attn_weights = self.attn(query=X, key=X, value=X)
        out = self.sigmoid(out)
        return out

    def save_pretrained(self, path):
        torch.save(self, f"{path}.pt")

    def from_pretrained(self, path):
        return torch.load(f"{path}.pt")


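# Illustrative sketch (never called by the pipeline): the attention module consumes a batch
# of stacked view representations; with aggfunc_type="concat" a linear layer maps the
# attention output to out_dim. The 300-dim input and the 73 labels are hypothetical.
def _example_attention_module():
    module = AttentionModule(
        embed_dim=300, num_heads=1, h_dim=512, out_dim=73, aggfunc_type="concat"
    )
    X = torch.rand(16, 300)  # 16 documents, 300-dim stacked representations
    logits = module(X)
    assert logits.shape == (16, 73)
    probs = module.transform(X)  # note: transform applies the sigmoid but not the linear layer
    assert probs.shape == (16, 300)

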
class AttentionAggregator:
    def __init__(
        self,
        embed_dim,
        out_dim,
        epochs,
        lr,
        patience,
        attn_stacking_type,
        h_dim=512,
        num_heads=1,
        device="cpu",
    ):
        self.embed_dim = embed_dim
        self.h_dim = h_dim
        self.out_dim = out_dim
        self.patience = patience
        self.num_heads = num_heads
        self.device = device
        self.epochs = epochs
        self.lr = lr
        self.stacking_type = attn_stacking_type
        self.tr_batch_size = 512
        self.eval_batch_size = 1024
        self.attn = AttentionModule(
            self.embed_dim,
            self.num_heads,
            self.h_dim,
            self.out_dim,
            aggfunc_type=self.stacking_type,
        ).to(self.device)

    def fit(self, X, Y):
        print("- fitting Attention-based aggregating function")
        hstacked_X = self.stack(X)

        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
            hstacked_X, Y, split=0.2, seed=42
        )

        tra_dataloader = DataLoader(
            AggregatorDatasetTorch(tr_lX, tr_lY, split="train"),
            batch_size=self.tr_batch_size,
            shuffle=True,
        )
        eval_dataloader = DataLoader(
            AggregatorDatasetTorch(val_lX, val_lY, split="eval"),
            batch_size=self.eval_batch_size,
            shuffle=False,
        )

        experiment_name = "attention_aggregator"
        trainer = Trainer(
            self.attn,
            optimizer_name="adamW",
            lr=self.lr,
            loss_fn=torch.nn.CrossEntropyLoss(),
            print_steps=25,
            evaluate_step=10,
            patience=self.patience,
            experiment_name=experiment_name,
            device=self.device,
            checkpoint_path="models/aggregator",
            # NOTE: the following three arguments are required by Trainer.__init__ and were
            # missing from the original call; the values below are assumed defaults.
            classification_type="multilabel",
            vgf_name="attention_aggregator",
            n_jobs=-1,
        )

        trainer.train(
            train_dataloader=tra_dataloader,
            eval_dataloader=eval_dataloader,
            epochs=self.epochs,
        )
        return self

    def transform(self, X):
        hstacked_X = self.stack(X)
        dataset = AggregatorDatasetTorch(hstacked_X, lY=None, split="whole")
        dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

        _embeds = []
        l_embeds = defaultdict(list)

        self.attn.eval()
        with torch.no_grad():
            for input_ids, lang in dataloader:
                input_ids = input_ids.to(self.device)
                out = self.attn.transform(input_ids)
                _embeds.append((out.cpu().numpy(), lang))

        for embed, lang in _embeds:
            for sample_embed, sample_lang in zip(embed, lang):
                l_embeds[sample_lang].append(sample_embed)

        l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}

        return l_embeds

    def stack(self, data):
        if self.stacking_type == "concat":
            hstack = self._concat_stack(data)
        elif self.stacking_type == "mean":
            hstack = self._mean_stack(data)
        else:
            raise ValueError(f"Stacking type {self.stacking_type} not supported")
        return hstack

    def _concat_stack(self, data):
        _langs = data[0].keys()
        l_projections = {}
        for l in _langs:
            l_projections[l] = torch.tensor(
                np.hstack([view[l] for view in data]), dtype=torch.float32
            )
        return l_projections

    def _mean_stack(self, data):
        # TODO: double check this mess
        aggregated = {lang: torch.zeros(d.shape) for lang, d in data[0].items()}
        for lang_projections in data:
            for lang, projection in lang_projections.items():
                aggregated[lang] += projection

        for lang, projection in aggregated.items():
            aggregated[lang] = (aggregated[lang] / len(data)).float()

        return aggregated

    def get_train_val_data(self, lX, lY, split=0.2, seed=42):
        tr_lX, tr_lY, val_lX, val_lY = {}, {}, {}, {}

        for lang in lX.keys():
            tr_X, val_X, tr_Y, val_Y = train_test_split(
                lX[lang], lY[lang], test_size=split, random_state=seed, shuffle=False
            )
            tr_lX[lang] = tr_X
            tr_lY[lang] = tr_Y
            val_lX[lang] = val_X
            val_lY[lang] = val_Y

        return tr_lX, tr_lY, val_lX, val_lY


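# Illustrative sketch (never called by the pipeline): how the aggregator stacks multiple
# per-language views. "concat" concatenates views feature-wise, "mean" averages them
# element-wise. Dimensions, label count, and language codes are hypothetical.
def _example_view_stacking():
    views = [
        {"en": torch.rand(10, 300), "it": torch.rand(8, 300)},  # e.g. one 300-dim VGF view
        {"en": torch.rand(10, 300), "it": torch.rand(8, 300)},  # e.g. another 300-dim view
    ]
    concat_agg = AttentionAggregator(
        embed_dim=600, out_dim=73, epochs=10, lr=1e-3, patience=5,
        attn_stacking_type="concat",
    )
    assert concat_agg.stack(views)["en"].shape == (10, 600)  # feature-wise concatenation
    mean_agg = AttentionAggregator(
        embed_dim=300, out_dim=73, epochs=10, lr=1e-3, patience=5,
        attn_stacking_type="mean",
    )
    assert mean_agg.stack(views)["it"].shape == (8, 300)  # element-wise average

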
class AggregatorDatasetTorch(Dataset):
    def __init__(self, lX, lY, split="train"):
        self.lX = lX
        self.lY = lY
        self.split = split
        self.langs = []
        self.init()

    def init(self):
        self.X = torch.vstack([data for data in self.lX.values()])
        if self.split != "whole":
            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
        # flatten [lang] * n_samples for every language into a single list aligned with self.X
        self.langs = sum(
            ([lang] * len(data) for lang, data in self.lX.items()), []
        )

        return self

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if self.split == "whole":
            return self.X[index], self.langs[index]
        return self.X[index], self.Y[index], self.langs[index]