gfun_multimodal/gfun/generalizedFunnelling.py


import os
import pickle
import numpy as np
from gfun.vgfs.commons import AttentionAggregator, TfidfVectorizerMultilingual, predict
from gfun.vgfs.learners.svms import MetaClassifier, get_learner
from gfun.vgfs.multilingualGen import MultilingualGen
from gfun.vgfs.textualTransformerGen import TextualTransformerGen
from gfun.vgfs.vanillaFun import VanillaFunGen
from gfun.vgfs.visualTransformerGen import VisualTransformerGen
from gfun.vgfs.wceGen import WceGen


class GeneralizedFunnelling:
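    """Two-tier 'funnelling' architecture: a set of view-generating functions
    (VGFs) projects documents from every language into a shared space
    (typically calibrated posterior probabilities), and a meta-classifier is
    trained on the aggregation of these first-tier projections.
    """
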
    def __init__(
        self,
        posterior,
        wce,
        multilingual,
        textual_transformer,
        visual_transformer,
        langs,
        num_labels,
        classification_type,
        embed_dir,
        n_jobs,
        batch_size,
        eval_batch_size,
        max_length,
        textual_lr,
        visual_lr,
        epochs,
        patience,
        evaluate_step,
        textual_transformer_name,
        visual_transformer_name,
        optimc,
        device,
        load_trained,
        dataset_name,
        probabilistic,
        aggfunc,
        load_meta,
    ):
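        """Store the configuration: which VGFs to instantiate (posterior,
        multilingual, wce, textual/visual transformer), their training
        hyper-parameters, the aggregation function for the first-tier
        projections, and the meta-classifier settings.
        """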
# Setting VFGs -----------
self.posteriors_vgf = posterior
self.wce_vgf = wce
self.multilingual_vgf = multilingual
self.textual_trf_vgf = textual_transformer
self.visual_trf_vgf = visual_transformer
self.probabilistic = probabilistic
self.num_labels = num_labels
self.clf_type = classification_type
# ------------------------
self.langs = langs
self.embed_dir = embed_dir
self.cached = True
# Textual Transformer VGF params ----------
self.textual_trf_name = textual_transformer_name
self.epochs = epochs
self.textual_trf_lr = textual_lr
self.textual_scheduler = "ReduceLROnPlateau"
self.batch_size_trf = batch_size
self.eval_batch_size_trf = eval_batch_size
self.max_length = max_length
self.early_stopping = True
self.patience = patience
self.evaluate_step = evaluate_step
self.device = device
# Visual Transformer VGF params ----------
self.visual_trf_name = visual_transformer_name
self.visual_trf_lr = visual_lr
self.visual_scheduler = "ReduceLROnPlateau"
# Metaclassifier params ------------
self.optimc = optimc
# -------------------
self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
self.n_jobs = n_jobs
self.first_tier_learners = []
self.metaclassifier = None
self.aggfunc = aggfunc
self.load_trained = load_trained
self.load_first_tier = (
True # TODO: i guess we're always going to load at least the first tier
)
self.load_meta = load_meta
self.dataset_name = dataset_name
self._init()

    def _init(self):
        print("\n[Init GeneralizedFunnelling]")
        assert not (
            self.aggfunc == "mean" and self.probabilistic is False
        ), "When using the averaging aggregation function, probabilistic must be True"

        if self.load_trained is not None:
            # TODO: clean up this code here
            print(
                "- loading trained VGFs, metaclassifier and vectorizer"
                if self.load_meta
                else "- loading trained VGFs and vectorizer"
            )
            self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load(
                self.load_trained,
                load_first_tier=self.load_first_tier,
                load_meta=self.load_meta,
            )
            if self.metaclassifier is None:
                self.metaclassifier = MetaClassifier(
                    meta_learner=get_learner(calibrate=True, kernel="rbf"),
                    meta_parameters=get_params(self.optimc),
                    n_jobs=self.n_jobs,
                )
            if "attn" in self.aggfunc:
                attn_stacking = self.aggfunc.split("_")[1]
                self.attn_aggregator = AttentionAggregator(
                    embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
                    out_dim=self.num_labels,
                    lr=self.textual_trf_lr,
                    patience=self.patience,
                    num_heads=1,
                    device=self.device,
                    epochs=self.epochs,
                    attn_stacking_type=attn_stacking,
                )
            return self
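
        # No trained model to load: build the requested VGFs from scratch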
        if self.posteriors_vgf:
            fun = VanillaFunGen(
                base_learner=get_learner(calibrate=True),
                n_jobs=self.n_jobs,
            )
            self.first_tier_learners.append(fun)
        if self.multilingual_vgf:
            multilingual_vgf = MultilingualGen(
                embed_dir=self.embed_dir,
                langs=self.langs,
                n_jobs=self.n_jobs,
                cached=self.cached,
                probabilistic=self.probabilistic,
            )
            self.first_tier_learners.append(multilingual_vgf)

        if self.wce_vgf:
            wce_vgf = WceGen(n_jobs=self.n_jobs)
            self.first_tier_learners.append(wce_vgf)
        if self.textual_trf_vgf:
            transformer_vgf = TextualTransformerGen(
                dataset_name=self.dataset_name,
                model_name=self.textual_trf_name,
                lr=self.textual_trf_lr,
                scheduler=self.textual_scheduler,
                epochs=self.epochs,
                batch_size=self.batch_size_trf,
                batch_size_eval=self.eval_batch_size_trf,
                max_length=self.max_length,
                print_steps=50,
                probabilistic=self.probabilistic,
                evaluate_step=self.evaluate_step,
                verbose=True,
                patience=self.patience,
                device=self.device,
                classification_type=self.clf_type,
            )
            self.first_tier_learners.append(transformer_vgf)
        if self.visual_trf_vgf:
            visual_transformer_vgf = VisualTransformerGen(
                dataset_name=self.dataset_name,
                model_name="vit",
                lr=self.visual_trf_lr,
                scheduler=self.visual_scheduler,
                epochs=self.epochs,
                batch_size=self.batch_size_trf,
                batch_size_eval=self.eval_batch_size_trf,
                probabilistic=self.probabilistic,
                evaluate_step=self.evaluate_step,
                patience=self.patience,
                device=self.device,
                classification_type=self.clf_type,
            )
            self.first_tier_learners.append(visual_transformer_vgf)
if "attn" in self.aggfunc:
attn_stacking = self.aggfunc.split("_")[1]
self.attn_aggregator = AttentionAggregator(
embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
out_dim=self.num_labels,
lr=self.textual_trf_lr,
patience=self.patience,
num_heads=1,
device=self.device,
epochs=self.epochs,
attn_stacking_type=attn_stacking,
)
self.metaclassifier = MetaClassifier(
meta_learner=get_learner(calibrate=True, kernel="rbf"),
meta_parameters=get_params(self.optimc),
n_jobs=self.n_jobs,
)
        self._model_id = get_unique_id(
            self.dataset_name,
            self.posteriors_vgf,
            self.multilingual_vgf,
            self.wce_vgf,
            self.textual_trf_vgf,
            self.visual_trf_vgf,
            self.aggfunc,
        )
        print(f"- model id: {self._model_id}")
        return self

    def init_vgfs_vectorizers(self):
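        """Share the fitted multilingual tf-idf vectorizer with the VGFs that
        consume tf-idf input (VanillaFunGen, MultilingualGen, WceGen)."""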
        for vgf in self.first_tier_learners:
            if isinstance(vgf, (VanillaFunGen, MultilingualGen, WceGen)):
                vgf.vectorizer = self.vectorizer

    def fit(self, lX, lY):
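        """Fit the first-tier VGFs on the training data, aggregate their
        projections, and train the meta-classifier on the result. If a trained
        model was loaded, only the missing components are (re-)fitted.
        """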
print("\n[Fitting GeneralizedFunnelling]")
if self.load_trained is not None:
print(
"- loaded first tier learners!"
if self.load_meta is False
else "- loaded trained model!"
)
"""
if we are only loading the first tier, we need to
transform the training data to train the meta-classifier
"""
if self.load_first_tier is True and self.load_meta is False:
projections = []
for vgf in self.first_tier_learners:
l_posteriors = vgf.transform(lX)
projections.append(l_posteriors)
agg = self.aggregate(projections, lY)
self.metaclassifier.fit(agg, lY)
return self
self.vectorizer.fit(lX)
self.init_vgfs_vectorizers()
projections = []
print("- fitting first tier learners")
for vgf in self.first_tier_learners:
l_posteriors = vgf.fit_transform(lX, lY)
projections.append(l_posteriors)
agg = self.aggregate(projections, lY)
self.metaclassifier.fit(agg, lY)
return self

    def transform(self, lX):
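        """Project lX through every first-tier VGF, aggregate the projections,
        and return the meta-classifier's output per language (posterior
        probabilities, converted via `predict` in the single-label case).
        """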
        projections = []
        for vgf in self.first_tier_learners:
            l_posteriors = vgf.transform(lX)
            projections.append(l_posteriors)
        agg = self.aggregate(projections)
        l_out = self.metaclassifier.predict_proba(agg)
        if self.clf_type == "singlelabel":
            for lang, preds in l_out.items():
                l_out[lang] = predict(preds, clf_type=self.clf_type)
        return l_out

    def fit_transform(self, lX, lY):
        return self.fit(lX, lY).transform(lX)

    def aggregate(self, first_tier_projections, lY=None):
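        """Combine the per-VGF projections into a single per-language matrix,
        either by averaging, horizontal concatenation, or a learned
        attention-based aggregator.
        """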
if self.aggfunc == "mean":
aggregated = self._aggregate_mean(first_tier_projections)
elif self.aggfunc == "concat":
aggregated = self._aggregate_concat(first_tier_projections)
# elif self.aggfunc == "attn":
elif "attn" in self.aggfunc:
aggregated = self._aggregate_attn(first_tier_projections, lY)
else:
raise NotImplementedError
return aggregated

    def _aggregate_attn(self, first_tier_projections, lY=None):
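        """Aggregate via the attention layer; lY is only passed at training
        time, when the attention aggregator must first be fitted."""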
        if lY is None:
            # at prediction time
            aggregated = self.attn_aggregator.transform(first_tier_projections)
        else:
            # at training time we must fit the attention layer
            self.attn_aggregator.fit(first_tier_projections, lY)
            aggregated = self.attn_aggregator.transform(first_tier_projections)
        return aggregated

    def _aggregate_concat(self, first_tier_projections):
        aggregated = {}
        for lang in self.langs:
            aggregated[lang] = np.hstack([v[lang] for v in first_tier_projections])
        return aggregated

    def _aggregate_mean(self, first_tier_projections):
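        """Element-wise average of the first-tier projections, per language."""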
        aggregated = {
            lang: np.zeros(data.shape)
            for lang, data in first_tier_projections[0].items()
        }
        for lang_projections in first_tier_projections:
            for lang, projection in lang_projections.items():
                aggregated[lang] += projection
        for lang in aggregated:
            aggregated[lang] /= len(first_tier_projections)
        return aggregated

    def get_config(self):
        c = {}
        for vgf in self.first_tier_learners:
            vgf_config = vgf.get_config()
            c.update(vgf_config)
        gfun_config = {
            "id": self._model_id,
            "aggfunc": self.aggfunc,
            "optimc": self.optimc,
            "dataset": self.dataset_name,
        }
        c["gFun"] = gfun_config
        return c

    def save(self, save_first_tier=True, save_meta=True):
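        """Pickle the vectorizer and, optionally, the first-tier VGFs and the
        meta-classifier under the `models/` directory."""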
print(f"- Saving GeneralizedFunnelling model with id: {self._model_id}")
os.makedirs(os.path.join("models", "vectorizer"), exist_ok=True)
with open(
os.path.join("models", "vectorizer", f"vectorizer_{self._model_id}.pkl"),
"wb",
) as f:
pickle.dump(self.vectorizer, f)
if save_first_tier:
self.save_first_tier_learners(model_id=self._model_id)
if save_meta:
with open(
os.path.join("models", "metaclassifier", f"meta_{self._model_id}.pkl"),
"wb",
) as f:
pickle.dump(self.metaclassifier, f)
return

    def save_first_tier_learners(self, model_id):
        for vgf in self.first_tier_learners:
            vgf.save_vgf(model_id=model_id)
        return self

    def load(self, model_id, load_first_tier=True, load_meta=True):
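        """Load the pickled vectorizer, the first-tier VGFs enabled in the
        current configuration and, if requested, the meta-classifier.
        NOTE: the visual transformer VGF is currently not re-loaded here.
        """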
print(f"- loading model id: {model_id}")
first_tier_learners = []
with open(
os.path.join("models", "vectorizer", f"vectorizer_{model_id}.pkl"), "rb"
) as f:
vectorizer = pickle.load(f)
if self.posteriors_vgf:
with open(
os.path.join(
"models", "vgfs", "posterior", f"vanillaFunGen_{model_id}.pkl"
),
"rb",
) as vgf:
first_tier_learners.append(pickle.load(vgf))
if self.multilingual_vgf:
with open(
os.path.join(
"models", "vgfs", "multilingual", f"multilingualGen_{model_id}.pkl"
),
"rb",
) as vgf:
first_tier_learners.append(pickle.load(vgf))
if self.wce_vgf:
with open(
os.path.join(
"models", "vgfs", "wordclass", f"wordClassGen_{model_id}.pkl"
),
"rb",
) as vgf:
first_tier_learners.append(pickle.load(vgf))
if self.textual_trf_vgf:
with open(
os.path.join(
"models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl"
),
"rb",
) as vgf:
first_tier_learners.append(pickle.load(vgf))
if load_meta:
with open(
os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb"
) as f:
metaclassifier = pickle.load(f)
else:
metaclassifier = None
return first_tier_learners, metaclassifier, vectorizer

    def _load_meta(self):
        raise NotImplementedError

    def _load_posterior(self):
        raise NotImplementedError

    def _load_multilingual(self):
        raise NotImplementedError

    def _load_wce(self):
        raise NotImplementedError

    def _load_transformer(self):
        raise NotImplementedError

    def get_attn_agg_dim(self, attn_stacking_type):
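        """Input dimension of the attention aggregator: one posterior vector
        per VGF, hence n_vgfs * num_labels when the projections are stacked by
        concatenation, and num_labels when they are averaged."""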
        if self.probabilistic and "attn" not in self.aggfunc:
            return len(self.first_tier_learners) * self.num_labels
        elif self.probabilistic and "attn" in self.aggfunc:
            if attn_stacking_type == "concat":
                return len(self.first_tier_learners) * self.num_labels
            elif attn_stacking_type == "mean":
                return self.num_labels
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError


def get_params(optimc=False):
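    """Hyper-parameter grid for the meta-classifier's SVM; returns None when
    the grid-search optimization is disabled."""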
    if not optimc:
        return None
    c_range = [1e4, 1e3, 1e2, 1e1, 1, 1e-1]
    kernel = "rbf"
    return [{"kernel": [kernel], "C": c_range, "gamma": ["auto"]}]


def get_unique_id(
    dataset_name,
    posterior,
    multilingual,
    wce,
    textual_transformer,
    visual_transformer,
    aggfunc,
):
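    """Build a readable model id of the form
    `<dataset>_<active VGF flags: p|m|w|t|v>_<aggfunc>_<yymmdd>`."""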
    from datetime import datetime

    now = datetime.now().strftime("%y%m%d")
    model_id = f"{dataset_name}_"
    model_id += "p" if posterior else ""
    model_id += "m" if multilingual else ""
    model_id += "w" if wce else ""
    model_id += "t" if textual_transformer else ""
    model_id += "v" if visual_transformer else ""
    model_id += f"_{aggfunc}"
    return f"{model_id}_{now}"