getter for gFun and VGFs config
parent 9d43ebb23b
commit 17d0003e48
@@ -309,15 +309,21 @@ class GeneralizedFunnelling:
         return aggregated

+    def get_config(self):
+        print("\n")
+        print("-" * 50)
+        print("[GeneralizedFunnelling config]")
+        print(f"- model trained on langs: {self.langs}")
+        print("-- View Generating Functions configurations:\n")
+        c = {}
+
+        for vgf in self.first_tier_learners:
+            print(vgf)
+            print("-" * 50)
+            vgf_config = vgf.get_config()
+            c.update(vgf_config)
+
+        gfun_config = {
+            "id": self._model_id,
+            "aggfunc": self.aggfunc,
+            "optimc": self.optimc,
+            "dataset": self.dataset_name,
+        }
+
+        c["gFun"] = gfun_config
+        return c
+
     def save(self, save_first_tier=True, save_meta=True):
         print(f"- Saving GeneralizedFunnelling model with id: {self._model_id}")
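For orientation, the dictionary returned by GeneralizedFunnelling.get_config() nests each VGF's settings under that VGF's own key and the funnelling-level settings under "gFun". A sketch of the expected shape (key names follow the diffs below; every value is a made-up placeholder, not output from a real run):

    # Illustrative only: the shape of the dict returned by
    # GeneralizedFunnelling.get_config(). Top-level keys come from the
    # individual VGFs' get_config() methods; all values below are
    # made-up placeholders.
    config = {
        "textual_trf": {
            "model_name": "mbert",   # placeholder
            "lr": 1e-5,              # placeholder
            "epochs": 50,            # placeholder
            # ... the remaining TransformerGen fields
        },
        "gFun": {
            "id": "gfun-placeholder-id",  # placeholder
            "aggfunc": "mean",            # placeholder
            "optimc": True,               # placeholder
            "dataset": "cls",             # placeholder
        },
    }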
@@ -277,3 +277,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
     def __str__(self):
         str = f"[Transformer VGF (t)]\n- model_name: {self.model_name}\n- max_length: {self.max_length}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
         return str
+
+    def get_config(self):
+        c = super().get_config()
+        return {"textual_trf": c}
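Wrapping the inherited config under a "textual_trf" key keeps the shared transformer hyperparameters from colliding when GeneralizedFunnelling merges the per-VGF dicts with c.update(...). A self-contained sketch of the pattern, using hypothetical stand-in classes:

    # Hypothetical stand-in classes illustrating the namespacing pattern;
    # these are not the actual gfun classes.
    class Base:
        def get_config(self):
            return {"lr": 1e-5, "epochs": 50}

    class TextualVGF(Base):
        def get_config(self):
            return {"textual_trf": super().get_config()}

    class VisualVGF(Base):
        def get_config(self):
            return {"visual_trf": super().get_config()}

    merged = {}
    for vgf in (TextualVGF(), VisualVGF()):
        merged.update(vgf.get_config())
    # merged == {"textual_trf": {...}, "visual_trf": {...}}: identical
    # inherited keys coexist because each lives in its own namespace.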
@@ -94,3 +94,21 @@ class TransformerGen:
             val_lY[lang] = val_Y

         return tr_lX, tr_lY, val_lX, val_lY
+
+    def get_config(self):
+        return {
+            "model_name": self.model_name,
+            "dataset_name": self.dataset_name,
+            "epochs": self.epochs,
+            "lr": self.lr,
+            "batch_size": self.batch_size,
+            "batch_size_eval": self.batch_size_eval,
+            "max_length": self.max_length,
+            "print_steps": self.print_steps,
+            "device": self.device,
+            "probabilistic": self.probabilistic,
+            "n_jobs": self.n_jobs,
+            "evaluate_step": self.evaluate_step,
+            "verbose": self.verbose,
+            "patience": self.patience,
+        }
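The base-class getter returns a flat dict of plain hyperparameter values, so it can be persisted directly for run bookkeeping. One caveat: if self.device holds a torch.device, it is not JSON-serializable as-is. A minimal sketch, assuming a JSON dump is wanted:

    import json

    def config_to_json(config: dict) -> str:
        # Objects like torch.device are not JSON-serializable, so fall
        # back to str() for anything the json module does not handle.
        return json.dumps(config, indent=2, default=str)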
@@ -186,3 +186,6 @@ class VisualTransformerGen(ViewGen, TransformerGen):
     def __str__(self):
         str = f"[Visual Transformer VGF (v)]\n- model_name: {self.model_name}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
         return str
+
+    def get_config(self):
+        return {"visual_trf": super().get_config()}
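Both transformer VGFs now namespace the shared base config ("textual_trf" above, "visual_trf" here), so the merge in GeneralizedFunnelling.get_config() relies on those top-level keys being unique. A hypothetical guard (not in the repo) that could protect the merge:

    def assert_disjoint_configs(vgfs):
        # Hypothetical sanity check: each VGF's get_config() should expose
        # a unique top-level key, otherwise dict.update() silently
        # overwrites earlier entries.
        seen = set()
        for vgf in vgfs:
            keys = set(vgf.get_config())
            overlap = seen & keys
            assert not overlap, f"duplicate top-level config keys: {overlap}"
            seen |= keys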
main.py (10 changed lines)
@@ -13,13 +13,10 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
 """
 TODO:
     - Transformers VGFs:
         - save/load for MT5ForSqeuenceClassification
         - freeze params method
         - log on step rather than epoch?
     - General:
         [!] zero-shot setup
         - CLS dataset is loading only "books" domain data
         - log on wandb also the other VGF results + final results
         - documents should be trimmed to the same length (for SVMs we are using way too long tokens)
     - Attention Aggregator:
         - experiment with weight init of Attention-aggregator
@@ -106,9 +103,10 @@ def main(args):
         n_jobs=args.n_jobs,
     )

-    wandb.init(
-        project="gfun", name=f"gFun-{get_config_name(args)}"
-    )  # TODO: Add config to log
+    config = gfun.get_config()
+
+    wandb.init(project="gfun", name=f"gFun-{get_config_name(args)}", config=config)
+
     gfun.fit(lX, lY)

     if args.load_trained is None and not args.nosave:
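wandb.init() accepts a plain, possibly nested dict through its config argument, and the values then appear on the run page where they can be used for filtering and grouping. A minimal standalone sketch of the call pattern (project and run names and the dict contents are placeholders; mode="offline" is only there so the snippet runs without a wandb login):

    import wandb

    # Placeholder config mirroring the nested shape produced by gfun.get_config().
    config = {"gFun": {"aggfunc": "mean"}, "textual_trf": {"lr": 1e-5}}

    # mode="offline" lets the snippet run without a logged-in wandb account.
    run = wandb.init(project="gfun", name="gFun-demo", mode="offline", config=config)
    print(run.config["gFun"])  # nested sections are reachable through run.config
    run.finish()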