attention-based aggregation function, first implementation, some hard-coded parameters
Parent: 2a42b21ac9
Commit: 13ada46c34
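For orientation, here is a small self-contained sketch (not the repository code) of the aggregation mechanism this commit introduces: per-language projections coming from the first-tier view-generating functions are horizontally stacked and passed through multi-head self-attention plus a linear head. The languages, shapes, and random data below are made up; the 73 labels and the 146-dimensional input mirror the hard-coded values in the diff.

import numpy as np
import torch
import torch.nn as nn

# Two first-tier views, each mapping language -> (n_docs, n_labels) projections.
views = [
    {"en": np.random.rand(4, 73), "it": np.random.rand(4, 73)},
    {"en": np.random.rand(4, 73), "it": np.random.rand(4, 73)},
]
embed_dim = 2 * 73                      # concatenation of the two views (the hard-coded 146)
attn = nn.MultiheadAttention(embed_dim, num_heads=1)
linear = nn.Linear(embed_dim, 73)       # out_dim = num_labels

for lang in ("en", "it"):
    # hstack the per-language projections of every view, as _hstack does in the diff
    stacked = torch.tensor(np.hstack([v[lang] for v in views]), dtype=torch.float32)
    x = stacked.unsqueeze(1)            # (seq=4, batch=1, embed=146)
    out, _ = attn(x, x, x)              # self-attention, as in the new AttentionModule
    logits = linear(out).squeeze(1)     # (4, 73) aggregated per-language scores

Note that the commit's AttentionAggregator batches all languages together through a DataLoader rather than looping per language; the loop above is only to make the per-language bookkeeping visible.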
@@ -6,7 +6,7 @@ sys.path.append(os.path.join(os.getcwd(), "gfun"))
 import pickle

 import numpy as np
-from vgfs.commons import TfidfVectorizerMultilingual
+from vgfs.commons import TfidfVectorizerMultilingual, AttentionAggregator
 from vgfs.learners.svms import MetaClassifier, get_learner
 from vgfs.multilingualGen import MultilingualGen
 from gfun.vgfs.textualTransformerGen import TextualTransformerGen
@@ -44,6 +44,7 @@ class GeneralizedFunnelling:
         self.multilingual_vgf = multilingual
         self.trasformer_vgf = transformer
         self.probabilistic = probabilistic
+        self.num_labels = 73  # TODO: hard-coded
         # ------------------------
         self.langs = langs
         self.embed_dir = embed_dir
@@ -81,7 +82,6 @@ class GeneralizedFunnelling:
             self.load_trained
         )
         # TODO: config like aggfunc, device, n_jobs, etc
-        return self

         if self.posteriors_vgf:
             fun = VanillaFunGen(
@@ -121,6 +121,15 @@ class GeneralizedFunnelling:
             )
             self.first_tier_learners.append(transformer_vgf)

+        if self.aggfunc == "attn":
+            self.attn_aggregator = AttentionAggregator(
+                embed_dim=self.get_attn_agg_dim(),
+                out_dim=self.num_labels,
+                num_heads=1,
+                device=self.device,
+                epochs=self.epochs,
+            )
+
         self.metaclassifier = MetaClassifier(
             meta_learner=get_learner(calibrate=True, kernel="rbf"),
             meta_parameters=get_params(self.optimc),
@@ -160,7 +169,7 @@ class GeneralizedFunnelling:
             l_posteriors = vgf.fit_transform(lX, lY)
             projections.append(l_posteriors)

-        agg = self.aggregate(projections)
+        agg = self.aggregate(projections, lY)
         self.metaclassifier.fit(agg, lY)

         return self
@@ -177,15 +186,27 @@ class GeneralizedFunnelling:
     def fit_transform(self, lX, lY):
        return self.fit(lX, lY).transform(lX)

-    def aggregate(self, first_tier_projections):
+    def aggregate(self, first_tier_projections, lY=None):
         if self.aggfunc == "mean":
             aggregated = self._aggregate_mean(first_tier_projections)
         elif self.aggfunc == "concat":
             aggregated = self._aggregate_concat(first_tier_projections)
+        elif self.aggfunc == "attn":
+            aggregated = self._aggregate_attn(first_tier_projections, lY)
         else:
             raise NotImplementedError
         return aggregated

+    def _aggregate_attn(self, first_tier_projections, lY=None):
+        if lY is None:
+            # at prediction time
+            aggregated = self.attn_aggregator.transform(first_tier_projections)
+        else:
+            # at training time we must fit the attention layer
+            self.attn_aggregator.fit(first_tier_projections, lY)
+            aggregated = self.attn_aggregator.transform(first_tier_projections)
+        return aggregated
+
     def _aggregate_concat(self, first_tier_projections):
         aggregated = {}
         for lang in self.langs:
@@ -201,7 +222,6 @@ class GeneralizedFunnelling:
             for lang, projection in lang_projections.items():
                 aggregated[lang] += projection

-        # Computing mean
         for lang, projection in aggregated.items():
             aggregated[lang] /= len(first_tier_projections)

@@ -281,6 +301,11 @@ class GeneralizedFunnelling:
             vectorizer = pickle.load(f)
         return first_tier_learners, metaclassifier, vectorizer

+    def get_attn_agg_dim(self):
+        # TODO: hardcoded for now
+        print("\n[NB: ATTN AGGREGATOR DIM HARD-CODED TO 146]\n")
+        return 146
+


def get_params(optimc=False):
    if not optimc:
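The aggregator's input dimension is hard-coded to 146 above. An assumption not stated in the commit: this matches the horizontal stacking that the new AttentionAggregator performs, i.e. num_labels times the number of stacked first-tier views (73 * 2). A tiny sketch of how the value could be derived rather than hard-coded; the helper name is hypothetical.

# Assumption: 146 = num_labels * number of hstacked first-tier views (73 * 2).
def compute_attn_agg_dim(num_labels: int, n_views: int) -> int:
    return num_labels * n_views

assert compute_attn_agg_dim(73, 2) == 146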
@@ -3,13 +3,18 @@ from collections import defaultdict

 import numpy as np
 import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader, Dataset
 from sklearn.decomposition import TruncatedSVD
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.preprocessing import normalize
 from torch.optim import AdamW
+from transformers.modeling_outputs import SequenceClassifierOutput

 from evaluation.evaluate import evaluate, log_eval

+PRINT_ON_EPOCH = 10
+

 def _normalize(lX, l2=True):
     return {lang: normalize(np.asarray(X)) for lang, X in lX.items()} if l2 else lX
@@ -107,6 +112,7 @@ class Trainer:
         evaluate_step,
         patience,
         experiment_name,
+        checkpoint_path,
     ):
         self.device = device
         self.model = model.to(device)
@@ -118,7 +124,7 @@ class Trainer:
         self.patience = patience
         self.earlystopping = EarlyStopping(
             patience=patience,
-            checkpoint_path="models/vgfs/transformer/",
+            checkpoint_path=checkpoint_path,
             verbose=True,
             experiment_name=experiment_name,
         )
@@ -163,9 +169,13 @@ class Trainer:
         for b_idx, (x, y, lang) in enumerate(dataloader):
             self.optimizer.zero_grad()
             y_hat = self.model(x.to(self.device))
+            if isinstance(y_hat, SequenceClassifierOutput):
                 loss = self.loss_fn(y_hat.logits, y.to(self.device))
+            else:
+                loss = self.loss_fn(y_hat, y.to(self.device))
             loss.backward()
             self.optimizer.step()
+            if (epoch + 1) % PRINT_ON_EPOCH == 0:
                 if b_idx % self.print_steps == 0:
                     print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
         return self
@@ -178,8 +188,12 @@ class Trainer:

         for b_idx, (x, y, lang) in enumerate(dataloader):
             y_hat = self.model(x.to(self.device))
+            if isinstance(y_hat, SequenceClassifierOutput):
                 loss = self.loss_fn(y_hat.logits, y.to(self.device))
                 predictions = predict(y_hat.logits, classification_type="multilabel")
+            else:
+                loss = self.loss_fn(y_hat, y.to(self.device))
+                predictions = predict(y_hat, classification_type="multilabel")

             for l, _true, _pred in zip(lang, y, predictions):
                 lY[l].append(_true.detach().cpu().numpy())
@@ -240,3 +254,135 @@ class EarlyStopping:
     def load_model(self, model):
         _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
         return model.from_pretrained(_checkpoint_dir)
+
+
+class AttentionModule(nn.Module):
+    def __init__(self, embed_dim, num_heads, out_dim):
+        super().__init__()
+        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
+        self.linear = nn.Linear(embed_dim, out_dim)
+
+    def __call__(self, X):
+        attn_out, attn_weights = self.attn(query=X, key=X, value=X)
+        out = self.linear(attn_out)
+        return out
+
+    def transform(self, X):
+        attn_out, attn_weights = self.attn(query=X, key=X, value=X)
+        return attn_out
+
+    def save_pretrained(self, path):
+        torch.save(self.state_dict(), f"{path}.pt")
+
+    def _wtf(self):
+        print("wtf")
+
+
+class AttentionAggregator:
+    def __init__(self, embed_dim, out_dim, epochs, num_heads=1, device="cpu"):
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.device = device
+        self.epochs = epochs
+        self.attn = AttentionModule(embed_dim, num_heads, out_dim).to(self.device)
+
+    def fit(self, X, Y):
+        print("- fitting Attention-based aggregating function")
+        hstacked_X = self.stack(X)
+
+        dataset = AggregatorDatasetTorch(hstacked_X, Y)
+        tra_dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
+
+        experiment_name = "attention_aggregator"
+        trainer = Trainer(
+            self.attn,
+            optimizer_name="adamW",
+            lr=1e-3,
+            loss_fn=torch.nn.CrossEntropyLoss(),
+            print_steps=100,
+            evaluate_step=1000,
+            patience=10,
+            experiment_name=experiment_name,
+            device=self.device,
+            checkpoint_path="models/aggregator",
+        )
+
+        trainer.train(
+            train_dataloader=tra_dataloader,
+            eval_dataloader=tra_dataloader,
+            epochs=self.epochs,
+        )
+        return self
+
+    def transform(self, X):
+        # TODO: implement transform
+        h_stacked = self.stack(X)
+        dataset = AggregatorDatasetTorch(h_stacked, lY=None, split="whole")
+        dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
+
+        _embeds = []
+        l_embeds = defaultdict(list)
+
+        self.attn.eval()
+        with torch.no_grad():
+            for input_ids, lang in dataloader:
+                input_ids = input_ids.to(self.device)
+                out = self.attn.transform(input_ids)
+                _embeds.append((out.cpu().numpy(), lang))
+
+        for embed, lang in _embeds:
+            for sample_embed, sample_lang in zip(embed, lang):
+                l_embeds[sample_lang].append(sample_embed)
+
+        l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}
+
+        return l_embeds
+
+    def stack(self, data):
+        hstack = self._hstack(data)
+        return hstack
+
+    def _hstack(self, data):
+        _langs = data[0].keys()
+        l_projections = {}
+        for l in _langs:
+            l_projections[l] = torch.tensor(
+                np.hstack([view[l] for view in data]), dtype=torch.float32
+            )
+        return l_projections
+
+    def _vstack(self, data):
+        return torch.vstack()
+
+
+class AggregatorDatasetTorch(Dataset):
+    def __init__(self, lX, lY, split="train"):
+        self.lX = lX
+        self.lY = lY
+        self.split = split
+        self.langs = []
+        self.init()
+
+    def init(self):
+        self.X = torch.vstack([data for data in self.lX.values()])
+        if self.split != "whole":
+            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
+        self.langs = sum(
+            [
+                v
+                for v in {
+                    lang: [lang] * len(data) for lang, data in self.lX.items()
+                }.values()
+            ],
+            [],
+        )
+
+        return self
+
+    def __len__(self):
+        return len(self.X)
+
+    def __getitem__(self, index):
+        if self.split == "whole":
+            return self.X[index], self.langs[index]
+        return self.X[index], self.Y[index], self.langs[index]
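A minimal usage sketch of the new AttentionAggregator on its own follows. The projection dicts and labels are random stand-ins, and the class is assumed to be importable from vgfs.commons together with the Trainer machinery it depends on; this is a sketch, not a test from the repository.

import numpy as np
from vgfs.commons import AttentionAggregator

# Two first-tier views, each a {lang: (n_docs, n_labels)} projection dict.
projections = [
    {"en": np.random.rand(8, 73), "it": np.random.rand(8, 73)},
    {"en": np.random.rand(8, 73), "it": np.random.rand(8, 73)},
]
lY = {lang: np.random.randint(0, 2, size=(8, 73)).astype(np.float32) for lang in ("en", "it")}

aggregator = AttentionAggregator(embed_dim=146, out_dim=73, epochs=2, num_heads=1, device="cpu")
aggregator.fit(projections, lY)                   # hstacks the views, trains attention + linear head
l_aggregated = aggregator.transform(projections)  # {lang: attention-aggregated embeddings}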
@@ -241,14 +241,6 @@ class MetaClassifier:
         else:
             return Z

-    # def stack(self, lZ, lY=None):
-    #     X_stacked = np.vstack(list(lZ.values()))
-    #     if lY is not None:
-    #         Y_stacked = np.vstack(list(lY.values()))
-    #         return X_stacked, Y_stacked
-    #     else:
-    #         return X_stacked
-
     def predict(self, lZ):
         lZ = _joblib_transform_multiling(
             self.standardizer.transform, lZ, n_jobs=self.n_jobs
@@ -140,6 +140,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             evaluate_step=self.evaluate_step,
             patience=self.patience,
             experiment_name=experiment_name,
+            checkpoint_path="models/vgfs/transformer",
         )
         trainer.train(
             train_dataloader=tra_dataloader,
@@ -124,6 +124,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
             evaluate_step=self.evaluate_step,
             patience=self.patience,
             experiment_name=experiment_name,
+            checkpoint_path="models/vgfs/transformer",
         )

         trainer.train(