log average metrics via wandb
This commit is contained in:
parent
5ef0904e0e
commit
7d0d6ba1f6
|
@ -38,6 +38,13 @@ def format_langkey_wandb(lang_dict):
|
|||
return log_dict
|
||||
|
||||
|
||||
def format_average_wandb(avg_dict):
|
||||
log_dict = {}
|
||||
for metric, value in avg_dict.items():
|
||||
log_dict[f"average metric/{metric}"] = value
|
||||
return log_dict
|
||||
|
||||
|
||||
def XdotM(X, M, sif):
|
||||
E = X.dot(M)
|
||||
if sif:
|
||||
|
@ -221,7 +228,7 @@ class Trainer:
|
|||
{
|
||||
"loss/val": eval_loss,
|
||||
**format_langkey_wandb(lang_metrics),
|
||||
"average/metrics": avg_metrics,
|
||||
**format_average_wandb(avg_metrics),
|
||||
},
|
||||
commit=False,
|
||||
)
|
||||
|
|
8
main.py
8
main.py
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
|
||||
|
||||
from argparse import ArgumentParser
|
||||
from time import time
|
||||
|
@ -9,9 +9,9 @@ from dataManager.utils import get_dataset
|
|||
from evaluation.evaluate import evaluate, log_eval
|
||||
from gfun.generalizedFunnelling import GeneralizedFunnelling
|
||||
|
||||
"""
|
||||
TODO:
|
||||
- [!] LR scheduler
|
||||
"""
|
||||
TODO:
|
||||
- [!] add support for mT5
|
||||
- [!] CLS dataset is loading only "books" domain data
|
||||
- [!] documents should be trimmed to the same length (?)
|
||||
- [!] overall gfun results logger
|
||||
|
|
Loading…
Reference in New Issue