log average metrics via wandb

This commit is contained in:
andreapdr 2023-03-10 11:21:33 +01:00
parent 5ef0904e0e
commit 7d0d6ba1f6
2 changed files with 12 additions and 5 deletions

View File

@ -38,6 +38,13 @@ def format_langkey_wandb(lang_dict):
return log_dict return log_dict
def format_average_wandb(avg_dict):
log_dict = {}
for metric, value in avg_dict.items():
log_dict[f"average metric/{metric}"] = value
return log_dict
def XdotM(X, M, sif): def XdotM(X, M, sif):
E = X.dot(M) E = X.dot(M)
if sif: if sif:
@ -221,7 +228,7 @@ class Trainer:
{ {
"loss/val": eval_loss, "loss/val": eval_loss,
**format_langkey_wandb(lang_metrics), **format_langkey_wandb(lang_metrics),
"average/metrics": avg_metrics, **format_average_wandb(avg_metrics),
}, },
commit=False, commit=False,
) )

View File

@ -1,6 +1,6 @@
import os import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0" os.environ["CUDA_VISIBLE_DEVICES"] = "3"
from argparse import ArgumentParser from argparse import ArgumentParser
from time import time from time import time
@ -9,9 +9,9 @@ from dataManager.utils import get_dataset
from evaluation.evaluate import evaluate, log_eval from evaluation.evaluate import evaluate, log_eval
from gfun.generalizedFunnelling import GeneralizedFunnelling from gfun.generalizedFunnelling import GeneralizedFunnelling
""" """
TODO: TODO:
- [!] LR scheduler - [!] add support for mT5
- [!] CLS dataset is loading only "books" domain data - [!] CLS dataset is loading only "books" domain data
- [!] documents should be trimmed to the same length (?) - [!] documents should be trimmed to the same length (?)
- [!] overall gfun results logger - [!] overall gfun results logger