attention-based aggregation function, first implementation, some hard-coded parameters

Andrea Pedrotti 2023-02-10 18:29:58 +01:00
parent 2a42b21ac9
commit 13ada46c34
5 changed files with 184 additions and 19 deletions
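This commit adds an attention-based alternative to the existing "mean" and "concat" aggregation functions of GeneralizedFunnelling: the per-language posterior matrices produced by the first-tier view-generating functions are horizontally stacked, passed through a single-head self-attention layer with a linear head sized to the label space, the layer is trained on the training labels, and its attended output is what the meta-classifier consumes. A minimal sketch of that core operation follows; it is not part of the diff and uses toy dimensions matching the hard-coded values below (73 labels, two views, embed_dim 146):

import torch
import torch.nn as nn

# Two toy first-tier views, each emitting 73-dimensional posteriors for 5 documents
view_a = torch.rand(5, 73)
view_b = torch.rand(5, 73)

# Horizontally stack the views for one language (cf. AttentionAggregator._hstack below)
stacked = torch.hstack([view_a, view_b])  # shape (5, 146)

# Self-attention over the stacked posteriors plus a linear head to the label space
attn = nn.MultiheadAttention(embed_dim=146, num_heads=1)
head = nn.Linear(146, 73)

attn_out, _ = attn(query=stacked, key=stacked, value=stacked)  # (5, 146): fed to the meta-classifier
logits = head(attn_out)                                        # (5, 73): used to train the aggregator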

View File

@@ -6,7 +6,7 @@ sys.path.append(os.path.join(os.getcwd(), "gfun"))
import pickle
import numpy as np
from vgfs.commons import TfidfVectorizerMultilingual, AttentionAggregator
from vgfs.learners.svms import MetaClassifier, get_learner
from vgfs.multilingualGen import MultilingualGen
from gfun.vgfs.textualTransformerGen import TextualTransformerGen
@@ -44,6 +44,7 @@ class GeneralizedFunnelling:
self.multilingual_vgf = multilingual
self.trasformer_vgf = transformer
self.probabilistic = probabilistic
self.num_labels = 73 # TODO: hard-coded
# ------------------------
self.langs = langs
self.embed_dir = embed_dir
@@ -81,7 +82,6 @@ class GeneralizedFunnelling:
self.load_trained
)
# TODO: config like aggfunc, device, n_jobs, etc
return self
if self.posteriors_vgf:
fun = VanillaFunGen(
@@ -121,6 +121,15 @@ class GeneralizedFunnelling:
)
self.first_tier_learners.append(transformer_vgf)
if self.aggfunc == "attn":
self.attn_aggregator = AttentionAggregator(
embed_dim=self.get_attn_agg_dim(),
out_dim=self.num_labels,
num_heads=1,
device=self.device,
epochs=self.epochs,
)
self.metaclassifier = MetaClassifier(
meta_learner=get_learner(calibrate=True, kernel="rbf"),
meta_parameters=get_params(self.optimc),
@@ -160,7 +169,7 @@ class GeneralizedFunnelling:
l_posteriors = vgf.fit_transform(lX, lY)
projections.append(l_posteriors)
agg = self.aggregate(projections, lY)
self.metaclassifier.fit(agg, lY)
return self
@@ -177,15 +186,27 @@ class GeneralizedFunnelling:
def fit_transform(self, lX, lY):
return self.fit(lX, lY).transform(lX)
def aggregate(self, first_tier_projections, lY=None):
if self.aggfunc == "mean":
aggregated = self._aggregate_mean(first_tier_projections)
elif self.aggfunc == "concat":
aggregated = self._aggregate_concat(first_tier_projections)
elif self.aggfunc == "attn":
aggregated = self._aggregate_attn(first_tier_projections, lY)
else:
raise NotImplementedError
return aggregated
def _aggregate_attn(self, first_tier_projections, lY=None):
if lY is None:
# at prediction time
aggregated = self.attn_aggregator.transform(first_tier_projections)
else:
# at training time we must fit the attention layer
self.attn_aggregator.fit(first_tier_projections, lY)
aggregated = self.attn_aggregator.transform(first_tier_projections)
return aggregated
def _aggregate_concat(self, first_tier_projections):
aggregated = {}
for lang in self.langs:
@@ -201,7 +222,6 @@ class GeneralizedFunnelling:
for lang, projection in lang_projections.items():
aggregated[lang] += projection
# Computing mean
for lang, projection in aggregated.items():
aggregated[lang] /= len(first_tier_projections)
@@ -281,6 +301,11 @@ class GeneralizedFunnelling:
vectorizer = pickle.load(f)
return first_tier_learners, metaclassifier, vectorizer
def get_attn_agg_dim(self):
# TODO: hardcoded for now
print("\n[NB: ATTN AGGREGATOR DIM HARD-CODED TO 146]\n")
return 146
def get_params(optimc=False):
if not optimc:
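The get_attn_agg_dim() added above returns the hard-coded value 146 (flagged as a TODO). Since the attention aggregator operates on the horizontally stacked first-tier projections, the dimension could in principle be derived from the projections themselves, i.e. the sum of the per-view output widths for any one language. A hedged sketch of one possible approach, not part of the diff; it assumes the same list of {language: matrix} dicts that aggregate() receives, and would require building the aggregator only once those projections are available:

def infer_attn_agg_dim(first_tier_projections):
    # first_tier_projections: list of {lang: np.ndarray} dicts, one per first-tier view
    sample_lang = next(iter(first_tier_projections[0]))
    # width of the hstacked representation = sum of each view's output dimension
    return sum(view[sample_lang].shape[1] for view in first_tier_projections)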

View File

@@ -3,13 +3,18 @@ from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import normalize
from torch.optim import AdamW
from transformers.modeling_outputs import SequenceClassifierOutput
from evaluation.evaluate import evaluate, log_eval
PRINT_ON_EPOCH = 10
def _normalize(lX, l2=True):
return {lang: normalize(np.asarray(X)) for lang, X in lX.items()} if l2 else lX
@@ -107,6 +112,7 @@ class Trainer:
evaluate_step,
patience,
experiment_name,
checkpoint_path,
):
self.device = device
self.model = model.to(device)
@@ -118,7 +124,7 @@ class Trainer:
self.patience = patience
self.earlystopping = EarlyStopping(
patience=patience,
checkpoint_path=checkpoint_path,
verbose=True,
experiment_name=experiment_name,
)
@@ -163,9 +169,13 @@ class Trainer:
for b_idx, (x, y, lang) in enumerate(dataloader):
self.optimizer.zero_grad()
y_hat = self.model(x.to(self.device))
if isinstance(y_hat, SequenceClassifierOutput):
loss = self.loss_fn(y_hat.logits, y.to(self.device))
else:
loss = self.loss_fn(y_hat, y.to(self.device))
loss.backward()
self.optimizer.step()
if (epoch + 1) % PRINT_ON_EPOCH == 0:
if b_idx % self.print_steps == 0:
print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
return self
@@ -178,8 +188,12 @@ class Trainer:
for b_idx, (x, y, lang) in enumerate(dataloader):
y_hat = self.model(x.to(self.device))
if isinstance(y_hat, SequenceClassifierOutput):
loss = self.loss_fn(y_hat.logits, y.to(self.device))
predictions = predict(y_hat.logits, classification_type="multilabel")
else:
loss = self.loss_fn(y_hat, y.to(self.device))
predictions = predict(y_hat, classification_type="multilabel")
for l, _true, _pred in zip(lang, y, predictions):
lY[l].append(_true.detach().cpu().numpy())
@@ -240,3 +254,135 @@ class EarlyStopping:
def load_model(self, model):
_checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
return model.from_pretrained(_checkpoint_dir)
class AttentionModule(nn.Module):
def __init__(self, embed_dim, num_heads, out_dim):
super().__init__()
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.linear = nn.Linear(embed_dim, out_dim)
def __call__(self, X):
attn_out, attn_weights = self.attn(query=X, key=X, value=X)
out = self.linear(attn_out)
return out
def transform(self, X):
attn_out, attn_weights = self.attn(query=X, key=X, value=X)
return attn_out
def save_pretrained(self, path):
torch.save(self.state_dict(), f"{path}.pt")
def _wtf(self):
print("wtf")
class AttentionAggregator:
def __init__(self, embed_dim, out_dim, epochs, num_heads=1, device="cpu"):
self.embed_dim = embed_dim
self.num_heads = num_heads
self.device = device
self.epochs = epochs
self.attn = AttentionModule(embed_dim, num_heads, out_dim).to(self.device)
def fit(self, X, Y):
print("- fitting Attention-based aggregating function")
hstacked_X = self.stack(X)
dataset = AggregatorDatasetTorch(hstacked_X, Y)
tra_dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
experiment_name = "attention_aggregator"
trainer = Trainer(
self.attn,
optimizer_name="adamW",
lr=1e-3,
loss_fn=torch.nn.CrossEntropyLoss(),
print_steps=100,
evaluate_step=1000,
patience=10,
experiment_name=experiment_name,
device=self.device,
checkpoint_path="models/aggregator",
)
trainer.train(
train_dataloader=tra_dataloader,
eval_dataloader=tra_dataloader,
epochs=self.epochs,
)
return self
def transform(self, X):
# TODO: implement transform
h_stacked = self.stack(X)
dataset = AggregatorDatasetTorch(h_stacked, lY=None, split="whole")
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
_embeds = []
l_embeds = defaultdict(list)
self.attn.eval()
with torch.no_grad():
for input_ids, lang in dataloader:
input_ids = input_ids.to(self.device)
out = self.attn.transform(input_ids)
_embeds.append((out.cpu().numpy(), lang))
for embed, lang in _embeds:
for sample_embed, sample_lang in zip(embed, lang):
l_embeds[sample_lang].append(sample_embed)
l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}
return l_embeds
def stack(self, data):
hstack = self._hstack(data)
return hstack
def _hstack(self, data):
_langs = data[0].keys()
l_projections = {}
for l in _langs:
l_projections[l] = torch.tensor(
np.hstack([view[l] for view in data]), dtype=torch.float32
)
return l_projections
def _vstack(self, data):
return torch.vstack()
class AggregatorDatasetTorch(Dataset):
def __init__(self, lX, lY, split="train"):
self.lX = lX
self.lY = lY
self.split = split
self.langs = []
self.init()
def init(self):
self.X = torch.vstack([data for data in self.lX.values()])
if self.split != "whole":
self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
self.langs = sum(
[
v
for v in {
lang: [lang] * len(data) for lang, data in self.lX.items()
}.values()
],
[],
)
return self
def __len__(self):
return len(self.X)
def __getitem__(self, index):
if self.split == "whole":
return self.X[index], self.langs[index]
return self.X[index], self.Y[index], self.langs[index]
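A minimal usage sketch of the new AttentionAggregator with toy data, not part of the diff; it assumes two first-tier views that each yield a {language: posterior-matrix} dict, which is the structure _hstack expects:

import numpy as np
from vgfs.commons import AttentionAggregator

# Two toy views over two languages: 4 documents each, 73-dimensional posteriors
view_post = {"en": np.random.rand(4, 73), "it": np.random.rand(4, 73)}
view_embed = {"en": np.random.rand(4, 73), "it": np.random.rand(4, 73)}
labels = {"en": np.random.randint(0, 2, (4, 73)), "it": np.random.randint(0, 2, (4, 73))}

aggregator = AttentionAggregator(embed_dim=146, out_dim=73, epochs=5, num_heads=1, device="cpu")
aggregator.fit([view_post, view_embed], labels)             # trains the attention layer on the stacked views
aggregated = aggregator.transform([view_post, view_embed])  # {lang: attended matrix of shape (4, 146)}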

View File

@@ -241,14 +241,6 @@ class MetaClassifier:
else:
return Z
# def stack(self, lZ, lY=None):
# X_stacked = np.vstack(list(lZ.values()))
# if lY is not None:
# Y_stacked = np.vstack(list(lY.values()))
# return X_stacked, Y_stacked
# else:
# return X_stacked
def predict(self, lZ):
lZ = _joblib_transform_multiling(
self.standardizer.transform, lZ, n_jobs=self.n_jobs

View File

@@ -140,6 +140,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
evaluate_step=self.evaluate_step,
patience=self.patience,
experiment_name=experiment_name,
checkpoint_path="models/vgfs/transformer",
)
trainer.train(
train_dataloader=tra_dataloader,

View File

@@ -124,6 +124,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
evaluate_step=self.evaluate_step,
patience=self.patience,
experiment_name=experiment_name,
checkpoint_path="models/vgfs/transformer",
)
trainer.train(