fixed bug: we were applying sigmoid function 2 times when training the Attention-based aggregator
This commit is contained in:
parent
fc98bc3924
commit
7041f7b651
|
@ -178,4 +178,5 @@ cython_debug/
|
|||
out/*
|
||||
amazon_cateogories.bu.txt
|
||||
models/*
|
||||
scripts/
|
||||
scripts/
|
||||
logger/*
|
|
@ -109,9 +109,7 @@ class GlamiDataset:
|
|||
def get_label_binarizer(self, labels):
|
||||
mlb = LabelBinarizer()
|
||||
mlb.fit(labels)
|
||||
print(
|
||||
f"- Label binarizer initialized with the following labels:\n{mlb.classes_}"
|
||||
)
|
||||
print(f"- Label binarizer initialized with {len(mlb.classes_)} labels")
|
||||
return mlb
|
||||
|
||||
def binarize_labels(self, labels):
|
||||
|
|
|
@ -23,8 +23,9 @@ def evaluate(ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1):
|
|||
return {lang: evals[i] for i, lang in enumerate(langs)}
|
||||
|
||||
|
||||
def log_eval(l_eval, phase="training"):
|
||||
print(f"\n[Results {phase}]")
|
||||
def log_eval(l_eval, phase="training", verbose=True):
|
||||
if verbose:
|
||||
print(f"\n[Results {phase}]")
|
||||
metrics = []
|
||||
for lang in l_eval.keys():
|
||||
macrof1, microf1, macrok, microk = l_eval[lang]
|
||||
|
@ -32,9 +33,10 @@ def log_eval(l_eval, phase="training"):
|
|||
if phase != "validation":
|
||||
print(f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}")
|
||||
averages = np.mean(np.array(metrics), axis=0)
|
||||
print(
|
||||
"Averages: MF1, mF1, MK, mK",
|
||||
np.round(averages, 3),
|
||||
"\n",
|
||||
)
|
||||
if verbose:
|
||||
print(
|
||||
"Averages: MF1, mF1, MK, mK",
|
||||
np.round(averages, 3),
|
||||
"\n",
|
||||
)
|
||||
return averages
|
||||
|
|
|
@ -156,7 +156,7 @@ class GeneralizedFunnelling:
|
|||
if "attn" in self.aggfunc:
|
||||
attn_stacking = self.aggfunc.split("_")[1]
|
||||
self.attn_aggregator = AttentionAggregator(
|
||||
embed_dim=self.get_attn_agg_dim(),
|
||||
embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
|
||||
out_dim=self.num_labels,
|
||||
lr=self.lr_transformer,
|
||||
patience=self.patience,
|
||||
|
@ -173,6 +173,7 @@ class GeneralizedFunnelling:
|
|||
)
|
||||
|
||||
self._model_id = get_unique_id(
|
||||
self.dataset_name,
|
||||
self.posteriors_vgf,
|
||||
self.multilingual_vgf,
|
||||
self.wce_vgf,
|
||||
|
@ -376,7 +377,7 @@ class GeneralizedFunnelling:
|
|||
def _load_transformer(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_attn_agg_dim(self, attn_stacking_type="concat"):
|
||||
def get_attn_agg_dim(self, attn_stacking_type):
|
||||
if self.probabilistic and "attn" not in self.aggfunc:
|
||||
return len(self.first_tier_learners) * self.num_labels
|
||||
elif self.probabilistic and "attn" in self.aggfunc:
|
||||
|
@ -398,11 +399,11 @@ def get_params(optimc=False):
|
|||
return [{"kernel": [kernel], "C": c_range, "gamma": ["auto"]}]
|
||||
|
||||
|
||||
def get_unique_id(posterior, multilingual, wce, transformer, aggfunc):
|
||||
def get_unique_id(dataset_name, posterior, multilingual, wce, transformer, aggfunc):
|
||||
from datetime import datetime
|
||||
|
||||
now = datetime.now().strftime("%y%m%d")
|
||||
model_id = ""
|
||||
model_id = dataset_name
|
||||
model_id += "p" if posterior else ""
|
||||
model_id += "m" if multilingual else ""
|
||||
model_id += "w" if wce else ""
|
||||
|
|
|
@ -126,7 +126,7 @@ class Trainer:
|
|||
self.earlystopping = EarlyStopping(
|
||||
patience=patience,
|
||||
checkpoint_path=checkpoint_path,
|
||||
verbose=True,
|
||||
verbose=False,
|
||||
experiment_name=experiment_name,
|
||||
)
|
||||
|
||||
|
@ -149,18 +149,19 @@ class Trainer:
|
|||
for epoch in range(epochs):
|
||||
self.train_epoch(train_dataloader, epoch)
|
||||
if (epoch + 1) % self.evaluate_steps == 0:
|
||||
metric_watcher = self.evaluate(eval_dataloader)
|
||||
print_eval = (epoch + 1) % 25 == 0
|
||||
metric_watcher = self.evaluate(eval_dataloader, print_eval=print_eval)
|
||||
stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
|
||||
if stop:
|
||||
print(
|
||||
f"- restoring best model from epoch {self.earlystopping.best_epoch}"
|
||||
f"- restoring best model from epoch {self.earlystopping.best_epoch} with best metric: {self.earlystopping.best_score:3f}"
|
||||
)
|
||||
self.model = self.earlystopping.load_model(self.model).to(
|
||||
self.device
|
||||
)
|
||||
break
|
||||
self.train_epoch(eval_dataloader, epoch=epoch)
|
||||
print(f"\n- last swipe on eval set")
|
||||
self.train_epoch(eval_dataloader, epoch=0)
|
||||
self.earlystopping.save_model(self.model)
|
||||
return self.model
|
||||
|
||||
|
@ -180,7 +181,7 @@ class Trainer:
|
|||
print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
|
||||
return self
|
||||
|
||||
def evaluate(self, dataloader):
|
||||
def evaluate(self, dataloader, print_eval=True):
|
||||
self.model.eval()
|
||||
|
||||
lY = defaultdict(list)
|
||||
|
@ -204,7 +205,7 @@ class Trainer:
|
|||
lY_hat[lang] = np.vstack(lY_hat[lang])
|
||||
|
||||
l_eval = evaluate(lY, lY_hat)
|
||||
average_metrics = log_eval(l_eval, phase="validation")
|
||||
average_metrics = log_eval(l_eval, phase="validation", verbose=print_eval)
|
||||
return average_metrics[0] # macro-F1
|
||||
|
||||
|
||||
|
@ -228,21 +229,23 @@ class EarlyStopping:
|
|||
|
||||
def __call__(self, validation, model, epoch):
|
||||
if validation > self.best_score:
|
||||
print(
|
||||
f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
|
||||
)
|
||||
if self.verbose:
|
||||
print(
|
||||
f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
|
||||
)
|
||||
self.best_score = validation
|
||||
self.counter = 0
|
||||
self.best_epoch = epoch
|
||||
# print(f"- earlystopping: Saving best model from epoch {epoch}")
|
||||
self.save_model(model)
|
||||
elif validation < (self.best_score + self.min_delta):
|
||||
self.counter += 1
|
||||
print(
|
||||
f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}"
|
||||
)
|
||||
if self.verbose:
|
||||
print(
|
||||
f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}"
|
||||
)
|
||||
if self.counter >= self.patience:
|
||||
if self.verbose:
|
||||
print(f"- earlystopping: Early stopping at epoch {epoch}")
|
||||
print(f"- earlystopping: Early stopping at epoch {epoch}")
|
||||
return True
|
||||
|
||||
def save_model(self, model):
|
||||
|
@ -256,36 +259,35 @@ class EarlyStopping:
|
|||
|
||||
|
||||
class AttentionModule(nn.Module):
|
||||
def __init__(self, embed_dim, num_heads, h_dim, out_dim):
|
||||
def __init__(self, embed_dim, num_heads, h_dim, out_dim, aggfunc_type):
|
||||
"""We are calling sigmoid on the evaluation loop (Trainer.evaluate), so we
|
||||
are not applying explicitly here at training time. However, we should
|
||||
explcitly squash outputs through the sigmoid at inference (self.transform) (???)
|
||||
"""
|
||||
super().__init__()
|
||||
self.aggfunc = aggfunc_type
|
||||
self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.1)
|
||||
self.layer_norm = nn.LayerNorm(embed_dim)
|
||||
self.linear = nn.Linear(embed_dim, out_dim)
|
||||
# self.layer_norm = nn.LayerNorm(embed_dim)
|
||||
if self.aggfunc == "concat":
|
||||
self.linear = nn.Linear(embed_dim, out_dim)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def __call__(self, X):
|
||||
out, attn_weights = self.attn(query=X, key=X, value=X)
|
||||
out = self.layer_norm(out)
|
||||
out = self.linear(out)
|
||||
# out = self.layer_norm(out)
|
||||
if self.aggfunc == "concat":
|
||||
out = self.linear(out)
|
||||
# out = self.sigmoid(out)
|
||||
return out
|
||||
# out = self.relu(out)
|
||||
# out = self.linear2(out)
|
||||
# out = self.sigmoid(out)
|
||||
|
||||
def transform(self, X):
|
||||
return self.__call__(X)
|
||||
# out, attn_weights = self.attn(query=X, key=X, value=X)
|
||||
# out = self.layer_norm(out)
|
||||
# out = self.linear(out)
|
||||
# out = self.sigmoid(out)
|
||||
# return out
|
||||
# out = self.relu(out)
|
||||
# out = self.linear2(out)
|
||||
# out = self.sigmoid(out)
|
||||
"""explicitly calling sigmoid at inference time"""
|
||||
out, attn_weights = self.attn(query=X, key=X, value=X)
|
||||
out = self.sigmoid(out)
|
||||
return out
|
||||
|
||||
def save_pretrained(self, path):
|
||||
torch.save(self, f"{path}.pt")
|
||||
# torch.save(self.state_dict(), f"{path}.pt")
|
||||
|
||||
def from_pretrained(self, path):
|
||||
return torch.load(f"{path}.pt")
|
||||
|
@ -316,7 +318,11 @@ class AttentionAggregator:
|
|||
self.tr_batch_size = 512
|
||||
self.eval_batch_size = 1024
|
||||
self.attn = AttentionModule(
|
||||
self.embed_dim, self.num_heads, self.h_dim, self.out_dim
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
self.h_dim,
|
||||
self.out_dim,
|
||||
aggfunc_type=self.stacking_type,
|
||||
).to(self.device)
|
||||
|
||||
def fit(self, X, Y):
|
||||
|
@ -345,7 +351,7 @@ class AttentionAggregator:
|
|||
lr=self.lr,
|
||||
loss_fn=torch.nn.CrossEntropyLoss(),
|
||||
print_steps=25,
|
||||
evaluate_step=50,
|
||||
evaluate_step=10,
|
||||
patience=self.patience,
|
||||
experiment_name=experiment_name,
|
||||
device=self.device,
|
||||
|
|
14
main.py
14
main.py
|
@ -17,6 +17,11 @@ TODO:
|
|||
- load pre-trained VGFs while retaining ability to train new ones (self.fitted = True in loaded? or smt like that)
|
||||
- test split in MultiNews dataset
|
||||
- when we load a model and change its config (eg change the agg func, re-train meta), we should store this model as a new one (save it)
|
||||
- FFNN posterior-probabilities' dependent
|
||||
- re-init langs when loading VGFs?
|
||||
- there is a mess about sigmoid in the Attention aggregator + and evaluation function (predict). We were applying sig() 2 times on the outputs (at pred and at eval)...
|
||||
- [!] loss of Attention-aggregator seems to be uncorrelated with Macro-F1 on the validation set!
|
||||
- aligner layer (suggestion by G.Puccetti)
|
||||
"""
|
||||
|
||||
|
||||
|
@ -125,15 +130,16 @@ def main(args):
|
|||
if args.load_trained is None and not args.nosave:
|
||||
gfun.save(save_first_tier=True, save_meta=True)
|
||||
|
||||
preds = gfun.transform(lX)
|
||||
|
||||
# print("- Computing evaluation on training set")
|
||||
# preds = gfun.transform(lX)
|
||||
# train_eval = evaluate(lY, preds)
|
||||
# log_eval(train_eval, phase="train")
|
||||
|
||||
timetr = time()
|
||||
print(f"- training completed in {timetr - tinit:.2f} seconds")
|
||||
|
||||
test_eval = evaluate(lY_te, gfun.transform(lX_te))
|
||||
gfun_preds = gfun.transform(lX_te)
|
||||
test_eval = evaluate(lY_te, gfun_preds)
|
||||
log_eval(test_eval, phase="test")
|
||||
|
||||
timeval = time()
|
||||
|
@ -156,7 +162,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("-m", "--multilingual", action="store_true")
|
||||
parser.add_argument("-w", "--wce", action="store_true")
|
||||
parser.add_argument("-t", "--transformer", action="store_true")
|
||||
parser.add_argument("--n_jobs", type=int, default=1)
|
||||
parser.add_argument("--n_jobs", type=int, default=-1)
|
||||
parser.add_argument("--optimc", action="store_true")
|
||||
parser.add_argument("--features", action="store_false")
|
||||
parser.add_argument("--aggfunc", type=str, default="mean")
|
||||
|
|
Loading…
Reference in New Issue