Fixed the loading function for the attention-based aggregating function when triggered by EarlyStopping

Andrea Pedrotti 2023-02-13 15:01:50 +01:00
parent 930a6d8275
commit 7ed98346a5
5 changed files with 243 additions and 86 deletions

View File

@@ -171,6 +171,9 @@ class MultilingualDataset:
else:
langs = sorted(self.multiling_dataset.keys())
return langs
def num_labels(self):
return self.num_categories()
def num_categories(self):
return self.lYtr()[self.langs()[0]].shape[1]
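For context, a minimal sketch (not part of the diff) of what the two new accessors compute, assuming the per-language training labels returned by `MultilingualDataset.lYtr()` are dense (n_docs, n_categories) matrices; the dictionary below is made up:

```python
import numpy as np

# Hypothetical per-language training label matrices, shaped like lYtr() output.
lYtr = {
    "en": np.zeros((1000, 73), dtype=int),
    "it": np.zeros((800, 73), dtype=int),
}

langs = sorted(lYtr.keys())
num_categories = lYtr[langs[0]].shape[1]  # width of the first language's label matrix
num_labels = num_categories               # num_labels() is just an alias for num_categories()
print(num_labels)  # 73, rather than the value previously hard-coded in GeneralizedFunnelling
```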

View File

@@ -22,6 +22,7 @@ class GeneralizedFunnelling:
multilingual,
transformer,
langs,
num_labels,
embed_dir,
n_jobs,
batch_size,
@@ -37,6 +38,7 @@
dataset_name,
probabilistic,
aggfunc,
load_meta,
):
# Setting VFGs -----------
self.posteriors_vgf = posterior
@@ -44,7 +46,7 @@
self.multilingual_vgf = multilingual
self.trasformer_vgf = transformer
self.probabilistic = probabilistic
self.num_labels = 73 # TODO: hard-coded
self.num_labels = num_labels
# ------------------------
self.langs = langs
self.embed_dir = embed_dir
@@ -68,6 +70,10 @@
self.metaclassifier = None
self.aggfunc = aggfunc
self.load_trained = load_trained
self.load_first_tier = (
True # TODO: I guess we're always going to load at least the first tier
)
self.load_meta = load_meta
self.dataset_name = dataset_name
self._init()
@@ -77,11 +83,37 @@
self.aggfunc == "mean" and self.probabilistic is False
), "When using averaging aggreagation function probabilistic must be True"
if self.load_trained is not None:
print("- loading trained VGFs, metaclassifer and vectorizer")
self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load(
self.load_trained
# TODO: clean up this code here
print(
"- loading trained VGFs, metaclassifer and vectorizer"
if self.load_meta
else "- loading trained VGFs and vectorizer"
)
# TODO: config like aggfunc, device, n_jobs, etc
self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load(
self.load_trained,
load_first_tier=self.load_first_tier,
load_meta=self.load_meta,
)
if self.metaclassifier is None:
self.metaclassifier = MetaClassifier(
meta_learner=get_learner(calibrate=True, kernel="rbf"),
meta_parameters=get_params(self.optimc),
n_jobs=self.n_jobs,
)
if "attn" in self.aggfunc:
attn_stacking = self.aggfunc.split("_")[1]
self.attn_aggregator = AttentionAggregator(
embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
out_dim=self.num_labels,
lr=self.lr_transformer,
patience=self.patience,
num_heads=1,
device=self.device,
epochs=self.epochs,
attn_stacking_type=attn_stacking,
)
return self
if self.posteriors_vgf:
fun = VanillaFunGen(
@@ -112,7 +144,7 @@
epochs=self.epochs,
batch_size=self.batch_size_transformer,
max_length=self.max_length,
device="cuda",
device=self.device,
print_steps=50,
probabilistic=self.probabilistic,
evaluate_step=self.evaluate_step,
@@ -121,13 +153,17 @@
)
self.first_tier_learners.append(transformer_vgf)
if self.aggfunc == "attn":
if "attn" in self.aggfunc:
attn_stacking = self.aggfunc.split("_")[1]
self.attn_aggregator = AttentionAggregator(
embed_dim=self.get_attn_agg_dim(),
out_dim=self.num_labels,
lr=self.lr_transformer,
patience=self.patience,
num_heads=1,
device=self.device,
epochs=self.epochs,
attn_stacking_type=attn_stacking,
)
self.metaclassifier = MetaClassifier(
@@ -141,6 +177,7 @@
self.multilingual_vgf,
self.wce_vgf,
self.trasformer_vgf,
self.aggfunc,
)
print(f"- model id: {self._model_id}")
return self
@@ -153,11 +190,19 @@
def fit(self, lX, lY):
print("[Fitting GeneralizedFunnelling]")
if self.load_trained is not None:
print(f"- loaded trained model! Skipping training...")
# TODO: add support to load only the first tier learners while re-training the metaclassifier
load_only_first_tier = False
if load_only_first_tier:
raise NotImplementedError
print(
"- loaded first tier learners!"
if self.load_meta is False
else "- loaded trained model!"
)
if self.load_first_tier is True and self.load_meta is False:
# TODO: clean up this code here
projections = []
for vgf in self.first_tier_learners:
l_posteriors = vgf.transform(lX)
projections.append(l_posteriors)
agg = self.aggregate(projections, lY)
self.metaclassifier.fit(agg, lY)
return self
self.vectorizer.fit(lX)
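A hedged mock of the control flow this hunk adds to `fit`: when a saved model is loaded without its metaclassifier (`load_meta=False`), the restored first-tier learners project the data, the projections are aggregated, and a fresh metaclassifier is fitted on top. `DummyVGF`, `DummyMeta` and `refit_meta` are illustrative placeholders, not classes or functions from the repository:

```python
import numpy as np

class DummyVGF:
    """Stand-in for a loaded first-tier view-generating function."""
    def transform(self, lX):
        # pretend each VGF returns per-language posterior matrices (3 labels here)
        return {lang: np.random.rand(len(docs), 3) for lang, docs in lX.items()}

class DummyMeta:
    """Stand-in for the MetaClassifier that gets re-fitted."""
    def fit(self, agg, lY):
        print({lang: m.shape for lang, m in agg.items()})
        return self

def refit_meta(first_tier_learners, metaclassifier, lX, lY):
    # mirrors the load_meta=False branch: project, aggregate (mean here), fit a new meta
    projections = [vgf.transform(lX) for vgf in first_tier_learners]
    agg = {lang: np.mean([p[lang] for p in projections], axis=0) for lang in lX}
    return metaclassifier.fit(agg, lY)

lX = {"en": ["doc"] * 4, "it": ["doc"] * 2}
lY = {"en": np.zeros((4, 3)), "it": np.zeros((2, 3))}
refit_meta([DummyVGF(), DummyVGF()], DummyMeta(), lX, lY)
```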
@@ -191,7 +236,8 @@
aggregated = self._aggregate_mean(first_tier_projections)
elif self.aggfunc == "concat":
aggregated = self._aggregate_concat(first_tier_projections)
elif self.aggfunc == "attn":
# elif self.aggfunc == "attn":
elif "attn" in self.aggfunc:
aggregated = self._aggregate_attn(first_tier_projections, lY)
else:
raise NotImplementedError
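The dispatch above now matches any aggregation function whose name contains "attn". A small sketch of the naming convention assumed throughout the diff, where the suffix after the underscore selects the stacking strategy (e.g. "attn_concat", "attn_mean"):

```python
def parse_aggfunc(aggfunc: str):
    """Split names such as 'attn_concat' or 'attn_mean' into (family, stacking_type).
    Plain 'mean' or 'concat' carry no stacking suffix."""
    if "attn" in aggfunc:
        family, stacking = aggfunc.split("_")  # e.g. ('attn', 'mean')
        return family, stacking
    return aggfunc, None

assert parse_aggfunc("attn_mean") == ("attn", "mean")
assert parse_aggfunc("attn_concat") == ("attn", "concat")
assert parse_aggfunc("mean") == ("mean", None)
```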
@@ -238,27 +284,41 @@
print(vgf)
print("-" * 50)
def save(self):
def save(self, save_first_tier=True, save_meta=True):
print(f"- Saving GeneralizedFunnelling model with id: {self._model_id}")
# TODO: save only the first tier learners? what about some model config + sanity checks before loading?
for vgf in self.first_tier_learners:
vgf.save_vgf(model_id=self._model_id)
os.makedirs(os.path.join("models", "metaclassifier"), exist_ok=True)
with open(
os.path.join("models", "metaclassifier", f"meta_{self._model_id}.pkl"), "wb"
) as f:
pickle.dump(self.metaclassifier, f)
os.makedirs(os.path.join("models", "vectorizer"), exist_ok=True)
with open(
os.path.join("models", "vectorizer", f"vectorizer_{self._model_id}.pkl"),
"wb",
) as f:
pickle.dump(self.vectorizer, f)
if save_first_tier:
self.save_first_tier_learners(model_id=self._model_id)
if save_meta:
with open(
os.path.join("models", "metaclassifier", f"meta_{self._model_id}.pkl"),
"wb",
) as f:
pickle.dump(self.metaclassifier, f)
return
def load(self, model_id):
def save_first_tier_learners(self, model_id):
for vgf in self.first_tier_learners:
vgf.save_vgf(model_id=self._model_id)
return self
def load(self, model_id, load_first_tier=True, load_meta=True):
print(f"- loading model id: {model_id}")
first_tier_learners = []
with open(
os.path.join("models", "vectorizer", f"vectorizer_{model_id}.pkl"), "rb"
) as f:
vectorizer = pickle.load(f)
if self.posteriors_vgf:
with open(
os.path.join(
@@ -291,20 +351,43 @@
"rb",
) as vgf:
first_tier_learners.append(pickle.load(vgf))
with open(
os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb"
) as f:
metaclassifier = pickle.load(f)
with open(
os.path.join("models", "vectorizer", f"vectorizer_{model_id}.pkl"), "rb"
) as f:
vectorizer = pickle.load(f)
if load_meta:
with open(
os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb"
) as f:
metaclassifier = pickle.load(f)
else:
metaclassifier = None
return first_tier_learners, metaclassifier, vectorizer
def get_attn_agg_dim(self):
# TODO: hardcoded for now
print("\n[NB: ATTN AGGREGATOR DIM HARD-CODED TO 146]\n")
return 146
def _load_meta(self):
raise NotImplementedError
def _load_posterior(self):
raise NotImplementedError
def _load_multilingual(self):
raise NotImplementedError
def _load_wce(self):
raise NotImplementedError
def _load_transformer(self):
raise NotImplementedError
def get_attn_agg_dim(self, attn_stacking_type="concat"):
if self.probabilistic and "attn" not in self.aggfunc:
return len(self.first_tier_learners) * self.num_labels
elif self.probabilistic and "attn" in self.aggfunc:
if attn_stacking_type == "concat":
return len(self.first_tier_learners) * self.num_labels
elif attn_stacking_type == "mean":
return self.num_labels
else:
raise NotImplementedError
else:
raise NotImplementedError
def get_params(optimc=False):
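A worked example of the dimensionality logic in `get_attn_agg_dim` just above, assuming three probabilistic first-tier VGFs and the 73-label RCV1/RCV2 space used elsewhere in this commit:

```python
n_first_tier = 3   # e.g. posteriors + WCE + transformer VGFs
num_labels = 73    # passed in from the dataset instead of being hard-coded

# stacking_type == "concat": per-language posteriors are concatenated horizontally,
# so the attention module sees n_first_tier * num_labels input features.
embed_dim_concat = n_first_tier * num_labels   # 219

# stacking_type == "mean": posteriors are averaged instead, so the input keeps
# the original label dimensionality.
embed_dim_mean = num_labels                    # 73
```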
@@ -315,7 +398,7 @@ def get_params(optimc=False):
return [{"kernel": [kernel], "C": c_range, "gamma": ["auto"]}]
def get_unique_id(posterior, multilingual, wce, transformer):
def get_unique_id(posterior, multilingual, wce, transformer, aggfunc):
from datetime import datetime
now = datetime.now().strftime("%y%m%d")
@@ -324,4 +407,5 @@ def get_unique_id(posterior, multilingual, wce, transformer):
model_id += "m" if multilingual else ""
model_id += "w" if wce else ""
model_id += "t" if transformer else ""
model_id += f"_{aggfunc}"
return f"{model_id}_{now}"

View File

@@ -10,6 +10,7 @@ from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import normalize
from torch.optim import AdamW
from transformers.modeling_outputs import SequenceClassifierOutput
from sklearn.model_selection import train_test_split
from evaluation.evaluate import evaluate, log_eval
@@ -158,7 +159,6 @@ class Trainer:
self.device
)
break
# TODO: maybe a lower lr?
self.train_epoch(eval_dataloader, epoch=epoch)
print(f"\n- last swipe on eval set")
self.earlystopping.save_model(self.model)
@@ -176,7 +176,7 @@
loss.backward()
self.optimizer.step()
if (epoch + 1) % PRINT_ON_EPOCH == 0:
if b_idx % self.print_steps == 0:
if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
return self
@@ -209,7 +209,6 @@
class EarlyStopping:
# TODO: add checkpointing + restore model if early stopping + last sweep on validation set
def __init__(
self,
patience,
@@ -247,8 +246,8 @@ class EarlyStopping:
return True
def save_model(self, model):
os.makedirs(self.checkpoint_path, exist_ok=True)
_checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
os.makedirs(_checkpoint_dir, exist_ok=True)
model.save_pretrained(_checkpoint_dir)
def load_model(self, model):
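This is the part the commit message refers to: `EarlyStopping.save_model` now writes into a per-experiment subdirectory, so the attention aggregator checkpoint can be located again when loading. A hedged sketch of the resulting layout, using the values this diff passes to the Trainer (the exact file name depends on the model's `save_pretrained`):

```python
import os

checkpoint_path = "models/aggregator"       # checkpoint_path given to Trainer in this diff
experiment_name = "attention_aggregator"    # experiment_name set in AttentionAggregator.fit

_checkpoint_dir = os.path.join(checkpoint_path, experiment_name)
os.makedirs(_checkpoint_dir, exist_ok=True)
# model.save_pretrained(_checkpoint_dir) then produces, for the AttentionModule,
# "models/aggregator/attention_aggregator.pt" since its save_pretrained appends ".pt";
# transformer models instead save a directory of files under _checkpoint_dir.
```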
@@ -257,51 +256,97 @@ class EarlyStopping:
class AttentionModule(nn.Module):
def __init__(self, embed_dim, num_heads, out_dim):
def __init__(self, embed_dim, num_heads, h_dim, out_dim):
super().__init__()
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.1)
self.layer_norm = nn.LayerNorm(embed_dim)
self.linear = nn.Linear(embed_dim, out_dim)
def __call__(self, X):
attn_out, attn_weights = self.attn(query=X, key=X, value=X)
out = self.linear(attn_out)
out, attn_weights = self.attn(query=X, key=X, value=X)
out = self.layer_norm(out)
out = self.linear(out)
# out = self.sigmoid(out)
return out
# out = self.relu(out)
# out = self.linear2(out)
# out = self.sigmoid(out)
def transform(self, X):
attn_out, attn_weights = self.attn(query=X, key=X, value=X)
return attn_out
return self.__call__(X)
# out, attn_weights = self.attn(query=X, key=X, value=X)
# out = self.layer_norm(out)
# out = self.linear(out)
# out = self.sigmoid(out)
# return out
# out = self.relu(out)
# out = self.linear2(out)
# out = self.sigmoid(out)
def save_pretrained(self, path):
torch.save(self.state_dict(), f"{path}.pt")
torch.save(self, f"{path}.pt")
# torch.save(self.state_dict(), f"{path}.pt")
def _wtf(self):
print("wtf")
def from_pretrained(self, path):
return torch.load(f"{path}.pt")
class AttentionAggregator:
def __init__(self, embed_dim, out_dim, epochs, num_heads=1, device="cpu"):
def __init__(
self,
embed_dim,
out_dim,
epochs,
lr,
patience,
attn_stacking_type,
h_dim=512,
num_heads=1,
device="cpu",
):
self.embed_dim = embed_dim
self.h_dim = h_dim
self.out_dim = out_dim
self.patience = patience
self.num_heads = num_heads
self.device = device
self.epochs = epochs
self.attn = AttentionModule(embed_dim, num_heads, out_dim).to(self.device)
self.lr = lr
self.stacking_type = attn_stacking_type
self.tr_batch_size = 512
self.eval_batch_size = 1024
self.attn = AttentionModule(
self.embed_dim, self.num_heads, self.h_dim, self.out_dim
).to(self.device)
def fit(self, X, Y):
print("- fitting Attention-based aggregating function")
hstacked_X = self.stack(X)
dataset = AggregatorDatasetTorch(hstacked_X, Y)
tra_dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
hstacked_X, Y, split=0.2, seed=42
)
tra_dataloader = DataLoader(
AggregatorDatasetTorch(tr_lX, tr_lY, split="train"),
batch_size=self.tr_batch_size,
shuffle=True,
)
eval_dataloader = DataLoader(
AggregatorDatasetTorch(val_lX, val_lY, split="eval"),
batch_size=self.eval_batch_size,
shuffle=False,
)
experiment_name = "attention_aggregator"
trainer = Trainer(
self.attn,
optimizer_name="adamW",
lr=1e-3,
lr=self.lr,
loss_fn=torch.nn.CrossEntropyLoss(),
print_steps=100,
evaluate_step=1000,
patience=10,
print_steps=25,
evaluate_step=50,
patience=self.patience,
experiment_name=experiment_name,
device=self.device,
checkpoint_path="models/aggregator",
@@ -309,15 +354,14 @@ class AttentionAggregator:
trainer.train(
train_dataloader=tra_dataloader,
eval_dataloader=tra_dataloader,
eval_dataloader=eval_dataloader,
epochs=self.epochs,
)
return self
def transform(self, X):
# TODO: implement transform
h_stacked = self.stack(X)
dataset = AggregatorDatasetTorch(h_stacked, lY=None, split="whole")
hstacked_X = self.stack(X)
dataset = AggregatorDatasetTorch(hstacked_X, lY=None, split="whole")
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
_embeds = []
@@ -339,10 +383,13 @@ class AttentionAggregator:
return l_embeds
def stack(self, data):
hstack = self._hstack(data)
if self.stacking_type == "concat":
hstack = self._concat_stack(data)
elif self.stacking_type == "mean":
hstack = self._mean_stack(data)
return hstack
def _hstack(self, data):
def _concat_stack(self, data):
_langs = data[0].keys()
l_projections = {}
for l in _langs:
@@ -351,8 +398,31 @@ class AttentionAggregator:
)
return l_projections
def _vstack(self, data):
return torch.vstack()
def _mean_stack(self, data):
# TODO: double check this mess
aggregated = {lang: torch.zeros(d.shape) for lang, d in data[0].items()}
for lang_projections in data:
for lang, projection in lang_projections.items():
aggregated[lang] += projection
for lang, projection in aggregated.items():
aggregated[lang] = (aggregated[lang] / len(data)).float()
return aggregated
def get_train_val_data(self, lX, lY, split=0.2, seed=42):
tr_lX, tr_lY, val_lX, val_lY = {}, {}, {}, {}
for lang in lX.keys():
tr_X, val_X, tr_Y, val_Y = train_test_split(
lX[lang], lY[lang], test_size=split, random_state=seed, shuffle=False
)
tr_lX[lang] = tr_X
tr_lY[lang] = tr_Y
val_lX[lang] = val_X
val_lY[lang] = val_Y
return tr_lX, tr_lY, val_lX, val_lY
class AggregatorDatasetTorch(Dataset):
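Given the "double check this mess" note, a toy verification (tensors made up) of what `_mean_stack` is meant to compute: an element-wise average, per language, over the projections produced by each first-tier VGF:

```python
import torch

# Two VGFs, each with per-language posteriors of shape (n_docs, n_labels).
data = [
    {"en": torch.full((2, 3), 1.0), "it": torch.full((2, 3), 3.0)},
    {"en": torch.full((2, 3), 3.0), "it": torch.full((2, 3), 5.0)},
]

aggregated = {lang: torch.zeros(t.shape) for lang, t in data[0].items()}
for lang_projections in data:
    for lang, projection in lang_projections.items():
        aggregated[lang] += projection
aggregated = {lang: (t / len(data)).float() for lang, t in aggregated.items()}

assert torch.allclose(aggregated["en"], torch.full((2, 3), 2.0))
assert torch.allclose(aggregated["it"], torch.full((2, 3), 4.0))
```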

View File

@@ -16,7 +16,6 @@ transformers.logging.set_verbosity_error()
class VisualTransformerGen(ViewGen, TransformerGen):
# TODO: probabilistic behaviour
def __init__(
self,
model_name,

main.py (41 changed lines)
View File

@@ -13,18 +13,23 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
TODO:
- add sphinx documentation
- zero-shot setup
- set probabilistic behaviour in Transformer parent-class
- pooling / attention aggregation
- load pre-trained VGFs while retaining the ability to train new ones (self.fitted = True in loaded? or something like that)
- test split in MultiNews dataset
- when we load a model and change its config (e.g. change the agg func, re-train meta), we should store this model as a new one (save it)
"""
def get_dataset(datasetname):
assert datasetname in ["multinews", "amazon", "rcv1-2"], "dataset not supported"
RCV_DATAPATH = expanduser(
"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
)
JRC_DATAPATH = expanduser(
"~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
)
MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
if datasetname == "multinews":
dataset = MultiNewsDataset(
expanduser(MULTINEWS_DATAPATH),
@@ -38,11 +43,9 @@ def get_dataset(datasetname):
max_labels=args.max_labels,
)
elif datasetname == "rcv1-2":
dataset = (
MultilingualDataset(dataset_name="rcv1-2")
.load(RCV_DATAPATH)
.reduce_data(langs=["en", "it", "fr"], maxn=args.nrows)
)
dataset = MultilingualDataset(dataset_name="rcv1-2").load(RCV_DATAPATH)
if args.nrows is not None:
dataset.reduce_data(langs=["en", "it", "fr"], maxn=args.nrows)
else:
raise NotImplementedError
return dataset
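A toy illustration of the new `--nrows` behaviour: with the new default of None the full RCV1-2 split is kept, and `reduce_data` only runs when a row cap is passed. `MockDataset` is a placeholder, not the real `MultilingualDataset`:

```python
class MockDataset:
    def __init__(self):
        self.docs = {lang: list(range(1000)) for lang in ("en", "it", "fr")}

    def reduce_data(self, langs, maxn):
        for lang in langs:
            self.docs[lang] = self.docs[lang][:maxn]
        return self

nrows = None              # new default of the --nrows argument below
dataset = MockDataset()
if nrows is not None:     # previously reduce_data always ran (old default was 100)
    dataset.reduce_data(langs=["en", "it", "fr"], maxn=nrows)
print(len(dataset.docs["en"]))  # 1000: the full split survives when --nrows is omitted
```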
@@ -55,11 +58,9 @@ def main(args):
):
lX, lY = dataset.training()
lX_te, lY_te = dataset.test()
# print("[NB: for debug purposes, training set is also used as test set]\n")
# lX_te, lY_te = dataset.training()
else:
_lX = dataset.dX
_lY = dataset.dY
lX = dataset.dX
lY = dataset.dY
tinit = time()
@@ -78,6 +79,7 @@ def main(args):
# dataset params ----------------------
dataset_name=args.dataset,
langs=dataset.langs(),
num_labels=dataset.num_labels(),
# Posterior VGF params ----------------
posterior=args.posteriors,
# Multilingual VGF params -------------
@@ -100,22 +102,20 @@ def main(args):
aggfunc=args.aggfunc,
optimc=args.optimc,
load_trained=args.load_trained,
load_meta=args.meta,
n_jobs=args.n_jobs,
)
# gfun.get_config()
gfun.fit(lX, lY)
if args.load_trained is not None:
gfun.save()
# if not args.load_model:
# gfun.save()
if args.load_trained is None:
gfun.save(save_first_tier=True, save_meta=True)
preds = gfun.transform(lX)
train_eval = evaluate(lY, preds)
log_eval(train_eval, phase="train")
# train_eval = evaluate(lY, preds)
# log_eval(train_eval, phase="train")
timetr = time()
print(f"- training completed in {timetr - tinit:.2f} seconds")
@@ -130,10 +130,11 @@ def main(args):
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("-l", "--load_trained", type=str, default=None)
parser.add_argument("--meta", action="store_true")
# Dataset parameters -------------------
parser.add_argument("-d", "--dataset", type=str, default="multinews")
parser.add_argument("--domains", type=str, default="all")
parser.add_argument("--nrows", type=int, default=100)
parser.add_argument("--nrows", type=int, default=None)
parser.add_argument("--min_count", type=int, default=10)
parser.add_argument("--max_labels", type=int, default=50)
# gFUN parameters ----------------------
@@ -148,7 +149,7 @@ if __name__ == "__main__":
# transformer parameters ---------------
parser.add_argument("--transformer_name", type=str, default="mbert")
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--epochs", type=int, default=1000)
parser.add_argument("--lr", type=float, default=1e-5)
parser.add_argument("--max_length", type=int, default=512)
parser.add_argument("--patience", type=int, default=5)