log average metrics via wandb

This commit is contained in:
andreapdr 2023-03-10 11:21:33 +01:00
parent 5ef0904e0e
commit 7d0d6ba1f6
2 changed files with 12 additions and 5 deletions

View File

@ -38,6 +38,13 @@ def format_langkey_wandb(lang_dict):
return log_dict
def format_average_wandb(avg_dict):
log_dict = {}
for metric, value in avg_dict.items():
log_dict[f"average metric/{metric}"] = value
return log_dict
def XdotM(X, M, sif):
E = X.dot(M)
if sif:
@ -221,7 +228,7 @@ class Trainer:
{
"loss/val": eval_loss,
**format_langkey_wandb(lang_metrics),
"average/metrics": avg_metrics,
**format_average_wandb(avg_metrics),
},
commit=False,
)

View File

@ -1,6 +1,6 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
from argparse import ArgumentParser
from time import time
@ -9,9 +9,9 @@ from dataManager.utils import get_dataset
from evaluation.evaluate import evaluate, log_eval
from gfun.generalizedFunnelling import GeneralizedFunnelling
"""
TODO:
- [!] LR scheduler
"""
TODO:
- [!] add support for mT5
- [!] CLS dataset is loading only "books" domain data
- [!] documents should be trimmed to the same length (?)
- [!] overall gfun results logger