log average metrics via wandb
This commit is contained in:
parent
5ef0904e0e
commit
7d0d6ba1f6
|
@ -38,6 +38,13 @@ def format_langkey_wandb(lang_dict):
|
||||||
return log_dict
|
return log_dict
|
||||||
|
|
||||||
|
|
||||||
|
def format_average_wandb(avg_dict):
|
||||||
|
log_dict = {}
|
||||||
|
for metric, value in avg_dict.items():
|
||||||
|
log_dict[f"average metric/{metric}"] = value
|
||||||
|
return log_dict
|
||||||
|
|
||||||
|
|
||||||
def XdotM(X, M, sif):
|
def XdotM(X, M, sif):
|
||||||
E = X.dot(M)
|
E = X.dot(M)
|
||||||
if sif:
|
if sif:
|
||||||
|
@ -221,7 +228,7 @@ class Trainer:
|
||||||
{
|
{
|
||||||
"loss/val": eval_loss,
|
"loss/val": eval_loss,
|
||||||
**format_langkey_wandb(lang_metrics),
|
**format_langkey_wandb(lang_metrics),
|
||||||
"average/metrics": avg_metrics,
|
**format_average_wandb(avg_metrics),
|
||||||
},
|
},
|
||||||
commit=False,
|
commit=False,
|
||||||
)
|
)
|
||||||
|
|
4
main.py
4
main.py
|
@ -1,6 +1,6 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
|
||||||
|
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from time import time
|
from time import time
|
||||||
|
@ -11,7 +11,7 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
|
||||||
|
|
||||||
"""
|
"""
|
||||||
TODO:
|
TODO:
|
||||||
- [!] LR scheduler
|
- [!] add support for mT5
|
||||||
- [!] CLS dataset is loading only "books" domain data
|
- [!] CLS dataset is loading only "books" domain data
|
||||||
- [!] documents should be trimmed to the same length (?)
|
- [!] documents should be trimmed to the same length (?)
|
||||||
- [!] overall gfun results logger
|
- [!] overall gfun results logger
|
||||||
|
|
Loading…
Reference in New Issue