getter for gFun and VGFs config

This commit is contained in:
Andrea Pedrotti 2023-03-16 11:41:40 +01:00
parent 9d43ebb23b
commit 17d0003e48
5 changed files with 42 additions and 13 deletions


@@ -309,15 +309,21 @@ class GeneralizedFunnelling:
        return aggregated

    def get_config(self):
        print("\n")
        print("-" * 50)
        print("[GeneralizedFunnelling config]")
        print(f"- model trained on langs: {self.langs}")
        print("-- View Generating Functions configurations:\n")

        c = {}
        for vgf in self.first_tier_learners:
            print(vgf)
            print("-" * 50)
            vgf_config = vgf.get_config()
            c.update(vgf_config)

        gfun_config = {
            "id": self._model_id,
            "aggfunc": self.aggfunc,
            "optimc": self.optimc,
            "dataset": self.dataset_name,
        }

        c["gFun"] = gfun_config
        return c

    def save(self, save_first_tier=True, save_meta=True):
        print(f"- Saving GeneralizedFunnelling model with id: {self._model_id}")


@@ -277,3 +277,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
    def __str__(self):
        _str = f"[Transformer VGF (t)]\n- model_name: {self.model_name}\n- max_length: {self.max_length}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
        return _str

    def get_config(self):
        c = super().get_config()
        return {"textual_trf": c}


@@ -94,3 +94,21 @@ class TransformerGen:
            val_lY[lang] = val_Y
        return tr_lX, tr_lY, val_lX, val_lY

    def get_config(self):
        return {
            "model_name": self.model_name,
            "dataset_name": self.dataset_name,
            "epochs": self.epochs,
            "lr": self.lr,
            "batch_size": self.batch_size,
            "batch_size_eval": self.batch_size_eval,
            "max_length": self.max_length,
            "print_steps": self.print_steps,
            "device": self.device,
            "probabilistic": self.probabilistic,
            "n_jobs": self.n_jobs,
            "evaluate_step": self.evaluate_step,
            "verbose": self.verbose,
            "patience": self.patience,
        }


@@ -186,3 +186,6 @@ class VisualTransformerGen(ViewGen, TransformerGen):
    def __str__(self):
        _str = f"[Visual Transformer VGF (v)]\n- model_name: {self.model_name}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
        return _str

    def get_config(self):
        return {"visual_trf": super().get_config()}

main.py

@@ -13,13 +13,10 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling

"""
TODO:
    - Transformers VGFs:
        - save/load for MT5ForSequenceClassification
        - freeze params method
        - log on step rather than on epoch?
    - General:
        [!] zero-shot setup
        - CLS dataset is loading only the "books" domain data
        - log the other VGF results + the final results on wandb as well
        - documents should be trimmed to the same length (for SVMs we are using sequences that are way too long)
    - Attention Aggregator:
        - experiment with the weight init of the Attention aggregator
@@ -106,9 +103,10 @@ def main(args):
        n_jobs=args.n_jobs,
    )

-   wandb.init(
-       project="gfun", name=f"gFun-{get_config_name(args)}"
-   )  # TODO: Add config to log
+   config = gfun.get_config()
+   wandb.init(project="gfun", name=f"gFun-{get_config_name(args)}", config=config)

    gfun.fit(lX, lY)

    if args.load_trained is None and not args.nosave:
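
Passing the dict through wandb.init(config=...) attaches the hyperparameters to the run, where they can be read back via wandb.config (both are standard wandb API). A hedged usage sketch, with placeholder project and run names:

import wandb

run = wandb.init(project="gfun", name="gFun-demo", config={"aggfunc": "mean", "optimc": True})
print(wandb.config.aggfunc)  # values passed at init are queryable on the run
run.finish()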