attention-based aggregation function, first implementation, some hard-coded parameters
This commit is contained in:
parent
2a42b21ac9
commit
13ada46c34
|
@ -6,7 +6,7 @@ sys.path.append(os.path.join(os.getcwd(), "gfun"))
|
|||
import pickle
|
||||
|
||||
import numpy as np
|
||||
from vgfs.commons import TfidfVectorizerMultilingual
|
||||
from vgfs.commons import TfidfVectorizerMultilingual, AttentionAggregator
|
||||
from vgfs.learners.svms import MetaClassifier, get_learner
|
||||
from vgfs.multilingualGen import MultilingualGen
|
||||
from gfun.vgfs.textualTransformerGen import TextualTransformerGen
|
||||
|
@ -44,6 +44,7 @@ class GeneralizedFunnelling:
|
|||
self.multilingual_vgf = multilingual
|
||||
self.trasformer_vgf = transformer
|
||||
self.probabilistic = probabilistic
|
||||
self.num_labels = 73 # TODO: hard-coded
|
||||
# ------------------------
|
||||
self.langs = langs
|
||||
self.embed_dir = embed_dir
|
||||
|
@ -81,7 +82,6 @@ class GeneralizedFunnelling:
|
|||
self.load_trained
|
||||
)
|
||||
# TODO: config like aggfunc, device, n_jobs, etc
|
||||
return self
|
||||
|
||||
if self.posteriors_vgf:
|
||||
fun = VanillaFunGen(
|
||||
|
@ -121,6 +121,15 @@ class GeneralizedFunnelling:
|
|||
)
|
||||
self.first_tier_learners.append(transformer_vgf)
|
||||
|
||||
if self.aggfunc == "attn":
|
||||
self.attn_aggregator = AttentionAggregator(
|
||||
embed_dim=self.get_attn_agg_dim(),
|
||||
out_dim=self.num_labels,
|
||||
num_heads=1,
|
||||
device=self.device,
|
||||
epochs=self.epochs,
|
||||
)
|
||||
|
||||
self.metaclassifier = MetaClassifier(
|
||||
meta_learner=get_learner(calibrate=True, kernel="rbf"),
|
||||
meta_parameters=get_params(self.optimc),
|
||||
|
@ -160,7 +169,7 @@ class GeneralizedFunnelling:
|
|||
l_posteriors = vgf.fit_transform(lX, lY)
|
||||
projections.append(l_posteriors)
|
||||
|
||||
agg = self.aggregate(projections)
|
||||
agg = self.aggregate(projections, lY)
|
||||
self.metaclassifier.fit(agg, lY)
|
||||
|
||||
return self
|
||||
|
@ -177,15 +186,27 @@ class GeneralizedFunnelling:
|
|||
def fit_transform(self, lX, lY):
|
||||
return self.fit(lX, lY).transform(lX)
|
||||
|
||||
def aggregate(self, first_tier_projections):
|
||||
def aggregate(self, first_tier_projections, lY=None):
|
||||
if self.aggfunc == "mean":
|
||||
aggregated = self._aggregate_mean(first_tier_projections)
|
||||
elif self.aggfunc == "concat":
|
||||
aggregated = self._aggregate_concat(first_tier_projections)
|
||||
elif self.aggfunc == "attn":
|
||||
aggregated = self._aggregate_attn(first_tier_projections, lY)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return aggregated
|
||||
|
||||
def _aggregate_attn(self, first_tier_projections, lY=None):
|
||||
if lY is None:
|
||||
# at prediction time
|
||||
aggregated = self.attn_aggregator.transform(first_tier_projections)
|
||||
else:
|
||||
# at training time we must fit the attention layer
|
||||
self.attn_aggregator.fit(first_tier_projections, lY)
|
||||
aggregated = self.attn_aggregator.transform(first_tier_projections)
|
||||
return aggregated
|
||||
|
||||
def _aggregate_concat(self, first_tier_projections):
|
||||
aggregated = {}
|
||||
for lang in self.langs:
|
||||
|
@ -201,7 +222,6 @@ class GeneralizedFunnelling:
|
|||
for lang, projection in lang_projections.items():
|
||||
aggregated[lang] += projection
|
||||
|
||||
# Computing mean
|
||||
for lang, projection in aggregated.items():
|
||||
aggregated[lang] /= len(first_tier_projections)
|
||||
|
||||
|
@ -281,6 +301,11 @@ class GeneralizedFunnelling:
|
|||
vectorizer = pickle.load(f)
|
||||
return first_tier_learners, metaclassifier, vectorizer
|
||||
|
||||
def get_attn_agg_dim(self):
|
||||
# TODO: hardcoded for now
|
||||
print("\n[NB: ATTN AGGREGATOR DIM HARD-CODED TO 146]\n")
|
||||
return 146
|
||||
|
||||
|
||||
def get_params(optimc=False):
|
||||
if not optimc:
|
||||
|
|
|
@ -3,13 +3,18 @@ from collections import defaultdict
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from sklearn.decomposition import TruncatedSVD
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.preprocessing import normalize
|
||||
from torch.optim import AdamW
|
||||
from transformers.modeling_outputs import SequenceClassifierOutput
|
||||
|
||||
from evaluation.evaluate import evaluate, log_eval
|
||||
|
||||
PRINT_ON_EPOCH = 10
|
||||
|
||||
|
||||
def _normalize(lX, l2=True):
|
||||
return {lang: normalize(np.asarray(X)) for lang, X in lX.items()} if l2 else lX
|
||||
|
@ -107,6 +112,7 @@ class Trainer:
|
|||
evaluate_step,
|
||||
patience,
|
||||
experiment_name,
|
||||
checkpoint_path,
|
||||
):
|
||||
self.device = device
|
||||
self.model = model.to(device)
|
||||
|
@ -118,7 +124,7 @@ class Trainer:
|
|||
self.patience = patience
|
||||
self.earlystopping = EarlyStopping(
|
||||
patience=patience,
|
||||
checkpoint_path="models/vgfs/transformer/",
|
||||
checkpoint_path=checkpoint_path,
|
||||
verbose=True,
|
||||
experiment_name=experiment_name,
|
||||
)
|
||||
|
@ -163,11 +169,15 @@ class Trainer:
|
|||
for b_idx, (x, y, lang) in enumerate(dataloader):
|
||||
self.optimizer.zero_grad()
|
||||
y_hat = self.model(x.to(self.device))
|
||||
loss = self.loss_fn(y_hat.logits, y.to(self.device))
|
||||
if isinstance(y_hat, SequenceClassifierOutput):
|
||||
loss = self.loss_fn(y_hat.logits, y.to(self.device))
|
||||
else:
|
||||
loss = self.loss_fn(y_hat, y.to(self.device))
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
if b_idx % self.print_steps == 0:
|
||||
print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
|
||||
if (epoch + 1) % PRINT_ON_EPOCH == 0:
|
||||
if b_idx % self.print_steps == 0:
|
||||
print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
|
||||
return self
|
||||
|
||||
def evaluate(self, dataloader):
|
||||
|
@ -178,8 +188,12 @@ class Trainer:
|
|||
|
||||
for b_idx, (x, y, lang) in enumerate(dataloader):
|
||||
y_hat = self.model(x.to(self.device))
|
||||
loss = self.loss_fn(y_hat.logits, y.to(self.device))
|
||||
predictions = predict(y_hat.logits, classification_type="multilabel")
|
||||
if isinstance(y_hat, SequenceClassifierOutput):
|
||||
loss = self.loss_fn(y_hat.logits, y.to(self.device))
|
||||
predictions = predict(y_hat.logits, classification_type="multilabel")
|
||||
else:
|
||||
loss = self.loss_fn(y_hat, y.to(self.device))
|
||||
predictions = predict(y_hat, classification_type="multilabel")
|
||||
|
||||
for l, _true, _pred in zip(lang, y, predictions):
|
||||
lY[l].append(_true.detach().cpu().numpy())
|
||||
|
@ -240,3 +254,135 @@ class EarlyStopping:
|
|||
def load_model(self, model):
|
||||
_checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
|
||||
return model.from_pretrained(_checkpoint_dir)
|
||||
|
||||
|
||||
class AttentionModule(nn.Module):
|
||||
def __init__(self, embed_dim, num_heads, out_dim):
|
||||
super().__init__()
|
||||
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
|
||||
self.linear = nn.Linear(embed_dim, out_dim)
|
||||
|
||||
def __call__(self, X):
|
||||
attn_out, attn_weights = self.attn(query=X, key=X, value=X)
|
||||
out = self.linear(attn_out)
|
||||
return out
|
||||
|
||||
def transform(self, X):
|
||||
attn_out, attn_weights = self.attn(query=X, key=X, value=X)
|
||||
return attn_out
|
||||
|
||||
def save_pretrained(self, path):
|
||||
torch.save(self.state_dict(), f"{path}.pt")
|
||||
|
||||
def _wtf(self):
|
||||
print("wtf")
|
||||
|
||||
|
||||
class AttentionAggregator:
|
||||
def __init__(self, embed_dim, out_dim, epochs, num_heads=1, device="cpu"):
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.device = device
|
||||
self.epochs = epochs
|
||||
self.attn = AttentionModule(embed_dim, num_heads, out_dim).to(self.device)
|
||||
|
||||
def fit(self, X, Y):
|
||||
print("- fitting Attention-based aggregating function")
|
||||
hstacked_X = self.stack(X)
|
||||
|
||||
dataset = AggregatorDatasetTorch(hstacked_X, Y)
|
||||
tra_dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
|
||||
|
||||
experiment_name = "attention_aggregator"
|
||||
trainer = Trainer(
|
||||
self.attn,
|
||||
optimizer_name="adamW",
|
||||
lr=1e-3,
|
||||
loss_fn=torch.nn.CrossEntropyLoss(),
|
||||
print_steps=100,
|
||||
evaluate_step=1000,
|
||||
patience=10,
|
||||
experiment_name=experiment_name,
|
||||
device=self.device,
|
||||
checkpoint_path="models/aggregator",
|
||||
)
|
||||
|
||||
trainer.train(
|
||||
train_dataloader=tra_dataloader,
|
||||
eval_dataloader=tra_dataloader,
|
||||
epochs=self.epochs,
|
||||
)
|
||||
return self
|
||||
|
||||
def transform(self, X):
|
||||
# TODO: implement transform
|
||||
h_stacked = self.stack(X)
|
||||
dataset = AggregatorDatasetTorch(h_stacked, lY=None, split="whole")
|
||||
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
|
||||
|
||||
_embeds = []
|
||||
l_embeds = defaultdict(list)
|
||||
|
||||
self.attn.eval()
|
||||
with torch.no_grad():
|
||||
for input_ids, lang in dataloader:
|
||||
input_ids = input_ids.to(self.device)
|
||||
out = self.attn.transform(input_ids)
|
||||
_embeds.append((out.cpu().numpy(), lang))
|
||||
|
||||
for embed, lang in _embeds:
|
||||
for sample_embed, sample_lang in zip(embed, lang):
|
||||
l_embeds[sample_lang].append(sample_embed)
|
||||
|
||||
l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}
|
||||
|
||||
return l_embeds
|
||||
|
||||
def stack(self, data):
|
||||
hstack = self._hstack(data)
|
||||
return hstack
|
||||
|
||||
def _hstack(self, data):
|
||||
_langs = data[0].keys()
|
||||
l_projections = {}
|
||||
for l in _langs:
|
||||
l_projections[l] = torch.tensor(
|
||||
np.hstack([view[l] for view in data]), dtype=torch.float32
|
||||
)
|
||||
return l_projections
|
||||
|
||||
def _vstack(self, data):
|
||||
return torch.vstack()
|
||||
|
||||
|
||||
class AggregatorDatasetTorch(Dataset):
|
||||
def __init__(self, lX, lY, split="train"):
|
||||
self.lX = lX
|
||||
self.lY = lY
|
||||
self.split = split
|
||||
self.langs = []
|
||||
self.init()
|
||||
|
||||
def init(self):
|
||||
self.X = torch.vstack([data for data in self.lX.values()])
|
||||
if self.split != "whole":
|
||||
self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
|
||||
self.langs = sum(
|
||||
[
|
||||
v
|
||||
for v in {
|
||||
lang: [lang] * len(data) for lang, data in self.lX.items()
|
||||
}.values()
|
||||
],
|
||||
[],
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def __len__(self):
|
||||
return len(self.X)
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.split == "whole":
|
||||
return self.X[index], self.langs[index]
|
||||
return self.X[index], self.Y[index], self.langs[index]
|
||||
|
|
|
@ -241,14 +241,6 @@ class MetaClassifier:
|
|||
else:
|
||||
return Z
|
||||
|
||||
# def stack(self, lZ, lY=None):
|
||||
# X_stacked = np.vstack(list(lZ.values()))
|
||||
# if lY is not None:
|
||||
# Y_stacked = np.vstack(list(lY.values()))
|
||||
# return X_stacked, Y_stacked
|
||||
# else:
|
||||
# return X_stacked
|
||||
|
||||
def predict(self, lZ):
|
||||
lZ = _joblib_transform_multiling(
|
||||
self.standardizer.transform, lZ, n_jobs=self.n_jobs
|
||||
|
|
|
@ -140,6 +140,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
evaluate_step=self.evaluate_step,
|
||||
patience=self.patience,
|
||||
experiment_name=experiment_name,
|
||||
checkpoint_path="models/vgfs/transformer",
|
||||
)
|
||||
trainer.train(
|
||||
train_dataloader=tra_dataloader,
|
||||
|
|
|
@ -124,6 +124,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
|||
evaluate_step=self.evaluate_step,
|
||||
patience=self.patience,
|
||||
experiment_name=experiment_name,
|
||||
checkpoint_path="models/vgfs/transformer",
|
||||
)
|
||||
|
||||
trainer.train(
|
||||
|
|
Loading…
Reference in New Issue