attention-based aggregation function, first implementation, some hard-coded parameters

This commit is contained in:
Andrea Pedrotti 2023-02-10 18:29:58 +01:00
parent 2a42b21ac9
commit 13ada46c34
5 changed files with 184 additions and 19 deletions

View File

@ -6,7 +6,7 @@ sys.path.append(os.path.join(os.getcwd(), "gfun"))
import pickle
import numpy as np
from vgfs.commons import TfidfVectorizerMultilingual
from vgfs.commons import TfidfVectorizerMultilingual, AttentionAggregator
from vgfs.learners.svms import MetaClassifier, get_learner
from vgfs.multilingualGen import MultilingualGen
from gfun.vgfs.textualTransformerGen import TextualTransformerGen
@ -44,6 +44,7 @@ class GeneralizedFunnelling:
self.multilingual_vgf = multilingual
self.trasformer_vgf = transformer
self.probabilistic = probabilistic
self.num_labels = 73 # TODO: hard-coded
# ------------------------
self.langs = langs
self.embed_dir = embed_dir
@ -81,7 +82,6 @@ class GeneralizedFunnelling:
self.load_trained
)
# TODO: config like aggfunc, device, n_jobs, etc
return self
if self.posteriors_vgf:
fun = VanillaFunGen(
@ -121,6 +121,15 @@ class GeneralizedFunnelling:
)
self.first_tier_learners.append(transformer_vgf)
if self.aggfunc == "attn":
self.attn_aggregator = AttentionAggregator(
embed_dim=self.get_attn_agg_dim(),
out_dim=self.num_labels,
num_heads=1,
device=self.device,
epochs=self.epochs,
)
self.metaclassifier = MetaClassifier(
meta_learner=get_learner(calibrate=True, kernel="rbf"),
meta_parameters=get_params(self.optimc),
@ -160,7 +169,7 @@ class GeneralizedFunnelling:
l_posteriors = vgf.fit_transform(lX, lY)
projections.append(l_posteriors)
agg = self.aggregate(projections)
agg = self.aggregate(projections, lY)
self.metaclassifier.fit(agg, lY)
return self
@ -177,15 +186,27 @@ class GeneralizedFunnelling:
def fit_transform(self, lX, lY):
return self.fit(lX, lY).transform(lX)
def aggregate(self, first_tier_projections):
def aggregate(self, first_tier_projections, lY=None):
if self.aggfunc == "mean":
aggregated = self._aggregate_mean(first_tier_projections)
elif self.aggfunc == "concat":
aggregated = self._aggregate_concat(first_tier_projections)
elif self.aggfunc == "attn":
aggregated = self._aggregate_attn(first_tier_projections, lY)
else:
raise NotImplementedError
return aggregated
def _aggregate_attn(self, first_tier_projections, lY=None):
if lY is None:
# at prediction time
aggregated = self.attn_aggregator.transform(first_tier_projections)
else:
# at training time we must fit the attention layer
self.attn_aggregator.fit(first_tier_projections, lY)
aggregated = self.attn_aggregator.transform(first_tier_projections)
return aggregated
def _aggregate_concat(self, first_tier_projections):
aggregated = {}
for lang in self.langs:
@ -201,7 +222,6 @@ class GeneralizedFunnelling:
for lang, projection in lang_projections.items():
aggregated[lang] += projection
# Computing mean
for lang, projection in aggregated.items():
aggregated[lang] /= len(first_tier_projections)
@ -281,6 +301,11 @@ class GeneralizedFunnelling:
vectorizer = pickle.load(f)
return first_tier_learners, metaclassifier, vectorizer
def get_attn_agg_dim(self):
# TODO: hardcoded for now
print("\n[NB: ATTN AGGREGATOR DIM HARD-CODED TO 146]\n")
return 146
def get_params(optimc=False):
if not optimc:

View File

@ -3,13 +3,18 @@ from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import normalize
from torch.optim import AdamW
from transformers.modeling_outputs import SequenceClassifierOutput
from evaluation.evaluate import evaluate, log_eval
PRINT_ON_EPOCH = 10
def _normalize(lX, l2=True):
return {lang: normalize(np.asarray(X)) for lang, X in lX.items()} if l2 else lX
@ -107,6 +112,7 @@ class Trainer:
evaluate_step,
patience,
experiment_name,
checkpoint_path,
):
self.device = device
self.model = model.to(device)
@ -118,7 +124,7 @@ class Trainer:
self.patience = patience
self.earlystopping = EarlyStopping(
patience=patience,
checkpoint_path="models/vgfs/transformer/",
checkpoint_path=checkpoint_path,
verbose=True,
experiment_name=experiment_name,
)
@ -163,9 +169,13 @@ class Trainer:
for b_idx, (x, y, lang) in enumerate(dataloader):
self.optimizer.zero_grad()
y_hat = self.model(x.to(self.device))
if isinstance(y_hat, SequenceClassifierOutput):
loss = self.loss_fn(y_hat.logits, y.to(self.device))
else:
loss = self.loss_fn(y_hat, y.to(self.device))
loss.backward()
self.optimizer.step()
if (epoch + 1) % PRINT_ON_EPOCH == 0:
if b_idx % self.print_steps == 0:
print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
return self
@ -178,8 +188,12 @@ class Trainer:
for b_idx, (x, y, lang) in enumerate(dataloader):
y_hat = self.model(x.to(self.device))
if isinstance(y_hat, SequenceClassifierOutput):
loss = self.loss_fn(y_hat.logits, y.to(self.device))
predictions = predict(y_hat.logits, classification_type="multilabel")
else:
loss = self.loss_fn(y_hat, y.to(self.device))
predictions = predict(y_hat, classification_type="multilabel")
for l, _true, _pred in zip(lang, y, predictions):
lY[l].append(_true.detach().cpu().numpy())
@ -240,3 +254,135 @@ class EarlyStopping:
def load_model(self, model):
_checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
return model.from_pretrained(_checkpoint_dir)
class AttentionModule(nn.Module):
def __init__(self, embed_dim, num_heads, out_dim):
super().__init__()
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.linear = nn.Linear(embed_dim, out_dim)
def __call__(self, X):
attn_out, attn_weights = self.attn(query=X, key=X, value=X)
out = self.linear(attn_out)
return out
def transform(self, X):
attn_out, attn_weights = self.attn(query=X, key=X, value=X)
return attn_out
def save_pretrained(self, path):
torch.save(self.state_dict(), f"{path}.pt")
def _wtf(self):
print("wtf")
class AttentionAggregator:
def __init__(self, embed_dim, out_dim, epochs, num_heads=1, device="cpu"):
self.embed_dim = embed_dim
self.num_heads = num_heads
self.device = device
self.epochs = epochs
self.attn = AttentionModule(embed_dim, num_heads, out_dim).to(self.device)
def fit(self, X, Y):
print("- fitting Attention-based aggregating function")
hstacked_X = self.stack(X)
dataset = AggregatorDatasetTorch(hstacked_X, Y)
tra_dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
experiment_name = "attention_aggregator"
trainer = Trainer(
self.attn,
optimizer_name="adamW",
lr=1e-3,
loss_fn=torch.nn.CrossEntropyLoss(),
print_steps=100,
evaluate_step=1000,
patience=10,
experiment_name=experiment_name,
device=self.device,
checkpoint_path="models/aggregator",
)
trainer.train(
train_dataloader=tra_dataloader,
eval_dataloader=tra_dataloader,
epochs=self.epochs,
)
return self
def transform(self, X):
# TODO: implement transform
h_stacked = self.stack(X)
dataset = AggregatorDatasetTorch(h_stacked, lY=None, split="whole")
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
_embeds = []
l_embeds = defaultdict(list)
self.attn.eval()
with torch.no_grad():
for input_ids, lang in dataloader:
input_ids = input_ids.to(self.device)
out = self.attn.transform(input_ids)
_embeds.append((out.cpu().numpy(), lang))
for embed, lang in _embeds:
for sample_embed, sample_lang in zip(embed, lang):
l_embeds[sample_lang].append(sample_embed)
l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}
return l_embeds
def stack(self, data):
hstack = self._hstack(data)
return hstack
def _hstack(self, data):
_langs = data[0].keys()
l_projections = {}
for l in _langs:
l_projections[l] = torch.tensor(
np.hstack([view[l] for view in data]), dtype=torch.float32
)
return l_projections
def _vstack(self, data):
return torch.vstack()
class AggregatorDatasetTorch(Dataset):
def __init__(self, lX, lY, split="train"):
self.lX = lX
self.lY = lY
self.split = split
self.langs = []
self.init()
def init(self):
self.X = torch.vstack([data for data in self.lX.values()])
if self.split != "whole":
self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
self.langs = sum(
[
v
for v in {
lang: [lang] * len(data) for lang, data in self.lX.items()
}.values()
],
[],
)
return self
def __len__(self):
return len(self.X)
def __getitem__(self, index):
if self.split == "whole":
return self.X[index], self.langs[index]
return self.X[index], self.Y[index], self.langs[index]

View File

@ -241,14 +241,6 @@ class MetaClassifier:
else:
return Z
# def stack(self, lZ, lY=None):
# X_stacked = np.vstack(list(lZ.values()))
# if lY is not None:
# Y_stacked = np.vstack(list(lY.values()))
# return X_stacked, Y_stacked
# else:
# return X_stacked
def predict(self, lZ):
lZ = _joblib_transform_multiling(
self.standardizer.transform, lZ, n_jobs=self.n_jobs

View File

@ -140,6 +140,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
evaluate_step=self.evaluate_step,
patience=self.patience,
experiment_name=experiment_name,
checkpoint_path="models/vgfs/transformer",
)
trainer.train(
train_dataloader=tra_dataloader,

View File

@ -124,6 +124,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
evaluate_step=self.evaluate_step,
patience=self.patience,
experiment_name=experiment_name,
checkpoint_path="models/vgfs/transformer",
)
trainer.train(