# gfun_multimodal/gfun/vgfs/commons.py
import os
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from transformers.modeling_outputs import ModelOutput
import wandb
from evaluation.evaluate import evaluate, log_eval
PRINT_ON_EPOCH = 1


def _normalize(lX, l2=True):
    return {lang: normalize(np.asarray(X)) for lang, X in lX.items()} if l2 else lX


def verbosity_eval(epoch, print_eval):
    return (epoch + 1) % print_eval == 0 and epoch != 0


def format_langkey_wandb(lang_dict, vgf_name):
    log_dict = {}
    for metric, l_dict in lang_dict.items():
        for lang, value in l_dict.items():
            log_dict[f"{vgf_name}/language metric/{metric}/{lang}"] = value
    return log_dict


def format_average_wandb(avg_dict, vgf_name):
    log_dict = {}
    for metric, value in avg_dict.items():
        log_dict[f"{vgf_name}/average metric/{metric}"] = value
    return log_dict


def XdotM(X, M, sif):
    """Project the document matrix X onto the embedding matrix M; if `sif` is
    set, remove the first principal component (common-component removal)."""
    E = X.dot(M)
    if sif:
        E = remove_pc(E, npc=1)
    return E


def remove_pc(X, npc=1):
    """
    Remove the projection on the principal components
    :param X: X[i,:] is a data point
    :param npc: number of principal components to remove
    :return: XX[i, :] is the data point after removing its projection
    """
    pc = compute_pc(X, npc)
    if npc == 1:
        XX = X - X.dot(pc.transpose()) * pc
    else:
        XX = X - X.dot(pc.transpose()).dot(pc)
    return XX


def compute_pc(X, npc=1):
    """
    Compute the principal components.
    :param X: X[i,:] is a data point
    :param npc: number of principal components to compute
    :return: component_[i,:] is the i-th pc
    """
    if isinstance(X, np.matrix):
        X = np.asarray(X)
    svd = TruncatedSVD(n_components=npc, n_iter=7, random_state=0)
    svd.fit(X)
    return svd.components_
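

# Usage sketch for the SIF-style pipeline above (comment only, so the module
# stays import-safe; shapes and names are illustrative):
#
#   X = np.random.rand(100, 50)   # 100 documents over a 50-term vocabulary
#   M = np.random.rand(50, 300)   # 50 term embeddings of dimension 300
#   E = XdotM(X, M, sif=True)     # project docs, then remove the 1st principal component
#   E.shape                       # -> (100, 300)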
def predict(logits, clf_type="multilabel"):
"""
Converts soft precictions to hard predictions [0,1]
"""
if clf_type == "multilabel":
prediction = torch.sigmoid(logits) > 0.5
return prediction.detach().cpu().numpy()
elif clf_type == "singlelabel":
if type(logits) != torch.Tensor:
logits = torch.tensor(logits)
prediction = torch.softmax(logits, dim=1)
prediction = prediction.detach().cpu().numpy()
_argmaxs = prediction.argmax(axis=1)
prediction = np.eye(prediction.shape[1])[_argmaxs]
return prediction
else:
raise NotImplementedError()
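

# Illustrative behaviour of `predict` (values chosen by hand):
#
#   logits = torch.tensor([[2.0, -1.0, 0.5]])
#   predict(logits, clf_type="multilabel")   # sigmoid > 0.5 -> [[True, False, True]]
#   predict(logits, clf_type="singlelabel")  # one-hot argmax -> [[1., 0., 0.]]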


class TfidfVectorizerMultilingual:
    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def fit(self, lX, ly=None):
        self.langs = sorted(lX.keys())
        self.vectorizer = {
            l: TfidfVectorizer(**self.kwargs).fit(lX[l]["text"]) for l in self.langs
        }
        return self

    def transform(self, lX):
        return {l: self.vectorizer[l].transform(lX[l]["text"]) for l in self.langs}

    def fit_transform(self, lX, ly=None):
        return self.fit(lX, ly).transform(lX)

    def vocabulary(self, l=None):
        if l is None:
            return {lang: self.vectorizer[lang].vocabulary_ for lang in self.langs}
        else:
            return self.vectorizer[l].vocabulary_

    def get_analyzer(self, l=None):
        if l is None:
            return {lang: self.vectorizer[lang].build_analyzer() for lang in self.langs}
        else:
            return self.vectorizer[l].build_analyzer()
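

# Usage sketch (assumes the lX format expected by fit/transform: language code
# -> dict with a "text" field; the data below is made up):
#
#   lX = {"en": {"text": ["a cat", "a dog"]}, "it": {"text": ["un gatto", "un cane"]}}
#   vec = TfidfVectorizerMultilingual(sublinear_tf=True)
#   lT = vec.fit_transform(lX)   # {"en": sparse matrix, "it": sparse matrix}
#   vec.vocabulary("en")         # {"cat": 0, "dog": 1}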


class Trainer:
    def __init__(
        self,
        model,
        optimizer_name,
        device,
        loss_fn,
        lr,
        print_steps,
        evaluate_step,
        patience,
        experiment_name,
        checkpoint_path,
        classification_type,
        vgf_name,
        n_jobs,
        scheduler_name=None,
    ):
        self.device = device
        self.model = model.to(device)
        self.optimizer, self.scheduler = self.init_optimizer(
            optimizer_name, lr, scheduler_name
        )
        self.evaluate_steps = evaluate_step
        self.loss_fn = loss_fn.to(device)
        self.print_steps = print_steps
        self.experiment_name = experiment_name
        self.patience = patience
        self.print_eval = 10
        self.earlystopping = EarlyStopping(
            patience=patience,
            checkpoint_path=checkpoint_path,
            verbose=False,
            experiment_name=experiment_name,
        )
        self.clf_type = classification_type
        self.vgf_name = vgf_name
        self.scheduler_name = scheduler_name
        self.n_jobs = n_jobs
        self.monitored_metric = (
            "macro-F1" if self.clf_type == "multilabel" else "accuracy"
        )  # TODO: make this configurable

    def init_optimizer(self, optimizer_name, lr, scheduler_name):
        if optimizer_name.lower() == "adamw":
            optim = AdamW(self.model.parameters(), lr=lr)
        else:
            raise ValueError(f"Optimizer {optimizer_name} not supported")
        if scheduler_name is None:
            scheduler = None
        elif scheduler_name == "ReduceLROnPlateau":
            scheduler = ReduceLROnPlateau(optim, "max", factor=0.5, min_lr=1e-5)
        else:
            raise ValueError(f"Scheduler {scheduler_name} not supported")
        return optim, scheduler

    def get_config(self, train_dataloader, eval_dataloader, epochs):
        return {
            # fall back to the class name for models that do not expose
            # `name_or_path` (e.g. AttentionModule)
            "model name": (
                self.model.mt5encoder.name_or_path
                if hasattr(self.model, "mt5encoder")
                else getattr(self.model, "name_or_path", type(self.model).__name__)
            ),
            "epochs": epochs,
            "learning rate": self.optimizer.defaults["lr"],
            "scheduler": self.scheduler_name,  # TODO: add scheduler params
            "train size": len(train_dataloader.dataset),
            "eval size": len(eval_dataloader.dataset),
            "train batch size": train_dataloader.batch_size,
            "eval batch size": eval_dataloader.batch_size,
            "max len": train_dataloader.dataset.X.shape[-1],
            "patience": self.earlystopping.patience,
            "evaluate every": self.evaluate_steps,
            "print eval every": self.print_eval,
            "print train steps": self.print_steps,
            "classification type": self.clf_type,
        }

    def train(self, train_dataloader, eval_dataloader, epochs=10):
        _config = self.get_config(train_dataloader, eval_dataloader, epochs)
        print(f"- Training params for {self.experiment_name}:")
        for k, v in _config.items():
            print(f"\t{k}: {v}")
        for epoch in range(epochs):
            train_loss = self.train_epoch(train_dataloader, epoch)
            if (epoch + 1) % self.evaluate_steps == 0 or (epoch + 1) == 1:
                print_eval = verbosity_eval(epoch, self.print_eval)
                with torch.no_grad():
                    eval_loss, avg_metrics, lang_metrics = self.evaluate(
                        eval_dataloader,
                        print_eval=print_eval,
                        n_jobs=self.n_jobs,
                    )
                wandb.log(
                    {
                        f"{self.vgf_name}/loss/val": eval_loss,
                        **format_langkey_wandb(lang_metrics, self.vgf_name),
                        **format_average_wandb(avg_metrics, self.vgf_name),
                    },
                    commit=False,
                )
                stop = self.earlystopping(
                    avg_metrics[self.monitored_metric], self.model, epoch + 1
                )
                if stop:
                    print(
                        f"- restoring best model from epoch {self.earlystopping.best_epoch} "
                        f"with best metric: {self.earlystopping.best_score:.3f}"
                    )
                    restored_model = self.earlystopping.load_model(self.model)
                    # swapping model on gpu
                    del self.model
                    self.model = restored_model.to(self.device)
                    break
                if self.scheduler is not None:
                    self.scheduler.step(avg_metrics[self.monitored_metric])
            wandb.log(
                {
                    f"{self.vgf_name}/loss/train": train_loss,
                    f"{self.vgf_name}/learning rate": self.optimizer.param_groups[0]["lr"],
                }
            )
        print("- last sweep on the eval set")
        self.train_epoch(
            DataLoader(
                eval_dataloader.dataset,
                batch_size=train_dataloader.batch_size,
                shuffle=True,
            ),
            epoch=-1,
        )
        self.earlystopping.save_model(self.model)
        return self.model

    def train_epoch(self, dataloader, epoch):
        self.model.train()
        batch_losses = []
        for b_idx, (x, y, lang) in enumerate(dataloader):
            self.optimizer.zero_grad()
            y_hat = self.model(x.to(self.device))
            if isinstance(y_hat, ModelOutput):
                loss = self.loss_fn(y_hat.logits, y.to(self.device))
            else:
                loss = self.loss_fn(y_hat, y.to(self.device))
            loss.backward()
            self.optimizer.step()
            batch_losses.append(loss.item())
            if (epoch + 1) % PRINT_ON_EPOCH == 0:
                if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
                    print(
                        f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {np.mean(batch_losses):.4f}"
                    )
        return np.mean(batch_losses)

    def evaluate(self, dataloader, print_eval=True, n_jobs=-1):
        self.model.eval()
        eval_losses = []
        lY_true = defaultdict(list)
        lY_pred = defaultdict(list)
        for b_idx, (x, y, lang) in enumerate(dataloader):
            y_pred = self.model(x.to(self.device))
            if isinstance(y_pred, ModelOutput):
                loss = self.loss_fn(y_pred.logits, y.to(self.device))
                predictions = predict(y_pred.logits, clf_type=self.clf_type)
            else:
                loss = self.loss_fn(y_pred, y.to(self.device))
                predictions = predict(y_pred, clf_type=self.clf_type)
            eval_losses.append(loss.item())
            for l, _true, _pred in zip(lang, y, predictions):
                lY_true[l].append(_true.detach().cpu().numpy())
                lY_pred[l].append(_pred)
        for lang in lY_true:
            lY_true[lang] = np.vstack(lY_true[lang])
            lY_pred[lang] = np.vstack(lY_pred[lang])
        l_eval = evaluate(lY_true, lY_pred, clf_type=self.clf_type, n_jobs=n_jobs)
        avg_metrics, lang_metrics = log_eval(
            l_eval, phase="validation", clf_type=self.clf_type, verbose=print_eval
        )
        return np.mean(eval_losses), avg_metrics, lang_metrics
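

# Trainer usage sketch (values illustrative; the model must implement
# save_pretrained/from_pretrained, since EarlyStopping checkpoints through them):
#
#   trainer = Trainer(
#       model=my_model, optimizer_name="adamw", device="cuda",
#       loss_fn=torch.nn.BCEWithLogitsLoss(), lr=1e-4, print_steps=25,
#       evaluate_step=10, patience=5, experiment_name="demo",
#       checkpoint_path="models/demo", classification_type="multilabel",
#       vgf_name="textual", n_jobs=-1,
#   )
#   trained_model = trainer.train(train_dl, eval_dl, epochs=50)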


class EarlyStopping:
    def __init__(
        self,
        patience,
        checkpoint_path,
        experiment_name,
        min_delta=0,
        verbose=True,
    ):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = 0
        self.best_epoch = None
        self.verbose = verbose
        self.checkpoint_path = checkpoint_path
        self.experiment_name = experiment_name

    def __call__(self, validation, model, epoch):
        if validation >= self.best_score:
            wandb.log({"patience": self.patience - self.counter})
            if self.verbose:
                print(
                    f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
                )
            self.best_score = validation
            self.counter = 0
            self.best_epoch = epoch
            # print(f"- earlystopping: Saving best model from epoch {epoch}")
            self.save_model(model)
        elif validation < (self.best_score + self.min_delta):
            self.counter += 1
            wandb.log({"patience": self.patience - self.counter})
            if self.verbose:
                print(
                    f"- earlystopping: Validation score did not improve ({validation:.3f} vs best {self.best_score:.3f}), "
                    f"current patience: {self.patience - self.counter}"
                )
            # patience == -1 disables early stopping
            if self.counter >= self.patience and self.patience != -1:
                print(f"- earlystopping: Early stopping at epoch {epoch}")
                return True
        return False

    def save_model(self, model):
        os.makedirs(self.checkpoint_path, exist_ok=True)
        _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
        model.save_pretrained(_checkpoint_dir)

    def load_model(self, model):
        _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
        return model.from_pretrained(_checkpoint_dir)
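

# EarlyStopping semantics sketch (illustrative; wandb must be initialized since
# the callable logs the remaining patience):
#
#   es = EarlyStopping(patience=2, checkpoint_path="models/demo",
#                      experiment_name="demo", verbose=False)
#   es(0.5, model, epoch=1)   # improvement: checkpoints model, returns False
#   es(0.4, model, epoch=2)   # no improvement: counter = 1, returns False
#   es(0.3, model, epoch=3)   # no improvement: counter = 2 -> returns True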


class AttentionModule(nn.Module):
    def __init__(self, embed_dim, num_heads, h_dim, out_dim, aggfunc_type):
        """The sigmoid is applied in the evaluation loop (via `predict` in
        Trainer.evaluate), so we do not apply it explicitly at training time;
        at inference, outputs are squashed through the sigmoid in
        self.transform.
        """
        super().__init__()
        self.aggfunc = aggfunc_type
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.1)
        # self.layer_norm = nn.LayerNorm(embed_dim)
        if self.aggfunc == "concat":
            self.linear = nn.Linear(embed_dim, out_dim)
        self.sigmoid = nn.Sigmoid()

    def init_weights(self, mode="mean"):
        # TODO: add init function of the attention module: either all weights are positive or set to 1/num_classes
        raise NotImplementedError

    def forward(self, X):
        # defined as `forward` (rather than overriding __call__) so that
        # nn.Module's call machinery and hooks keep working
        out, attn_weights = self.attn(query=X, key=X, value=X)
        # out = self.layer_norm(out)
        if self.aggfunc == "concat":
            out = self.linear(out)
            # out = self.sigmoid(out)
        return out

    def transform(self, X):
        """explicitly calling sigmoid at inference time"""
        out, attn_weights = self.attn(query=X, key=X, value=X)
        if self.aggfunc == "concat":  # keep inference consistent with forward
            out = self.linear(out)
        out = self.sigmoid(out)
        return out

    def save_pretrained(self, path):
        torch.save(self, f"{path}.pt")

    def from_pretrained(self, path):
        return torch.load(f"{path}.pt")


class AttentionAggregator:
    def __init__(
        self,
        embed_dim,
        out_dim,
        epochs,
        lr,
        patience,
        attn_stacking_type,
        h_dim=512,
        num_heads=1,
        device="cpu",
    ):
        self.embed_dim = embed_dim
        self.h_dim = h_dim
        self.out_dim = out_dim
        self.patience = patience
        self.num_heads = num_heads
        self.device = device
        self.epochs = epochs
        self.lr = lr
        self.stacking_type = attn_stacking_type
        self.tr_batch_size = 512
        self.eval_batch_size = 1024
        self.attn = AttentionModule(
            self.embed_dim,
            self.num_heads,
            self.h_dim,
            self.out_dim,
            aggfunc_type=self.stacking_type,
        ).to(self.device)

    def fit(self, X, Y):
        print("- fitting Attention-based aggregating function")
        hstacked_X = self.stack(X)
        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
            hstacked_X, Y, split=0.2, seed=42
        )
        tra_dataloader = DataLoader(
            AggregatorDatasetTorch(tr_lX, tr_lY, split="train"),
            batch_size=self.tr_batch_size,
            shuffle=True,
        )
        eval_dataloader = DataLoader(
            AggregatorDatasetTorch(val_lX, val_lY, split="eval"),
            batch_size=self.eval_batch_size,
            shuffle=False,
        )
        experiment_name = "attention_aggregator"
        trainer = Trainer(
            self.attn,
            optimizer_name="adamW",
            lr=self.lr,
            loss_fn=torch.nn.CrossEntropyLoss(),
            print_steps=25,
            evaluate_step=10,
            patience=self.patience,
            experiment_name=experiment_name,
            device=self.device,
            checkpoint_path="models/aggregator",
            # Trainer defines no defaults for the following three arguments, so
            # they must be passed explicitly; the values here are assumptions
            classification_type="multilabel",
            vgf_name="attention_aggregator",
            n_jobs=-1,
        )
        trainer.train(
            train_dataloader=tra_dataloader,
            eval_dataloader=eval_dataloader,
            epochs=self.epochs,
        )
        return self

    def transform(self, X):
        hstacked_X = self.stack(X)
        dataset = AggregatorDatasetTorch(hstacked_X, lY=None, split="whole")
        dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
        _embeds = []
        l_embeds = defaultdict(list)
        self.attn.eval()
        with torch.no_grad():
            for input_ids, lang in dataloader:
                input_ids = input_ids.to(self.device)
                out = self.attn.transform(input_ids)
                _embeds.append((out.cpu().numpy(), lang))
        for embed, lang in _embeds:
            for sample_embed, sample_lang in zip(embed, lang):
                l_embeds[sample_lang].append(sample_embed)
        l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}
        return l_embeds

    def stack(self, data):
        if self.stacking_type == "concat":
            hstack = self._concat_stack(data)
        elif self.stacking_type == "mean":
            hstack = self._mean_stack(data)
        return hstack

    def _concat_stack(self, data):
        _langs = data[0].keys()
        l_projections = {}
        for l in _langs:
            l_projections[l] = torch.tensor(
                np.hstack([view[l] for view in data]), dtype=torch.float32
            )
        return l_projections

    def _mean_stack(self, data):
        # average the per-view projections, language by language
        aggregated = {lang: torch.zeros(d.shape) for lang, d in data[0].items()}
        for lang_projections in data:
            for lang, projection in lang_projections.items():
                aggregated[lang] += projection
        for lang, projection in aggregated.items():
            aggregated[lang] = (aggregated[lang] / len(data)).float()
        return aggregated

    def get_train_val_data(self, lX, lY, split=0.2, seed=42):
        tr_lX, tr_lY, val_lX, val_lY = {}, {}, {}, {}
        for lang in lX.keys():
            tr_X, val_X, tr_Y, val_Y = train_test_split(
                lX[lang], lY[lang], test_size=split, random_state=seed, shuffle=False
            )
            tr_lX[lang] = tr_X
            tr_lY[lang] = tr_Y
            val_lX[lang] = val_X
            val_lY[lang] = val_Y
        return tr_lX, tr_lY, val_lX, val_lY
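

# AttentionAggregator usage sketch (illustrative): `views` is a list of per-VGF
# dictionaries, each mapping a language code to an (n_docs, embed_dim) matrix,
# and lY maps each language to its (n_docs, n_classes) label matrix:
#
#   agg = AttentionAggregator(embed_dim=512, out_dim=73, epochs=100, lr=1e-4,
#                             patience=5, attn_stacking_type="mean")
#   agg.fit(views, lY)
#   l_embeds = agg.transform(views)   # language -> aggregated embeddings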


class AggregatorDatasetTorch(Dataset):
    def __init__(self, lX, lY, split="train"):
        self.lX = lX
        self.lY = lY
        self.split = split
        self.langs = []
        self.init()

    def init(self):
        self.X = torch.vstack([data for data in self.lX.values()])
        if self.split != "whole":
            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
        # one language tag per sample, in the same order as the stacked X
        self.langs = [lang for lang, data in self.lX.items() for _ in range(len(data))]
        return self

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if self.split == "whole":
            return self.X[index], self.langs[index]
        return self.X[index], self.Y[index], self.langs[index]
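

# Batch structure note: for split="train"/"eval" a DataLoader over this dataset
# yields (x, y, lang) batches with lang collated into a tuple of language codes,
# matching what Trainer.train_epoch and Trainer.evaluate unpack; for
# split="whole" it yields (x, lang), as consumed by AttentionAggregator.transform.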