# Patch summary (free text before the first "diff --git" header; patch tools skip it).
#
# gfun/vgfs/commons.py
#   * Adds format_average_wandb(avg_dict), which returns a new dict holding each
#     entry of avg_dict under the key "average metric/<metric>".
#   * In the Trainer hunk, the single "average/metrics": avg_metrics entry of the
#     logged dict is replaced by the unpacked result of format_average_wandb, so
#     each averaged metric becomes its own top-level key.
# main.py
#   * CUDA_VISIBLE_DEVICES changes from "0" to "3".
#     NOTE(review): this pins a specific GPU index in code -- confirm it is meant
#     to be committed rather than being a local setting.
#   * The TODO entry "LR scheduler" is replaced by "add support for mT5".
#   * The removed/added '"""' and 'TODO:' lines are textually identical here;
#     presumably they differed only in whitespace -- TODO confirm.
#
# NOTE(review): this patch was recovered from a copy whose newlines had been
# collapsed. Line breaks follow the hunk headers' line counts; the leading
# indentation of hunk lines is a best guess and must be checked against the
# target files before applying.
diff --git a/gfun/vgfs/commons.py b/gfun/vgfs/commons.py
index 9a317af..6c4d59c 100644
--- a/gfun/vgfs/commons.py
+++ b/gfun/vgfs/commons.py
@@ -38,6 +38,13 @@ def format_langkey_wandb(lang_dict):
     return log_dict
 
 
+def format_average_wandb(avg_dict):
+    log_dict = {}
+    for metric, value in avg_dict.items():
+        log_dict[f"average metric/{metric}"] = value
+    return log_dict
+
+
 def XdotM(X, M, sif):
     E = X.dot(M)
     if sif:
@@ -221,7 +228,7 @@ class Trainer:
                 {
                     "loss/val": eval_loss,
                     **format_langkey_wandb(lang_metrics),
-                    "average/metrics": avg_metrics,
+                    **format_average_wandb(avg_metrics),
                 },
                 commit=False,
             )
diff --git a/main.py b/main.py
index 9efce33..05ef6ee 100644
--- a/main.py
+++ b/main.py
@@ -1,6 +1,6 @@
 import os
 
-os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+os.environ["CUDA_VISIBLE_DEVICES"] = "3"
 
 from argparse import ArgumentParser
 from time import time
@@ -9,9 +9,9 @@
 from dataManager.utils import get_dataset
 from evaluation.evaluate import evaluate, log_eval
 from gfun.generalizedFunnelling import GeneralizedFunnelling
-"""
-TODO:
-    - [!] LR scheduler
+"""
+TODO:
+    - [!] add support for mT5
     - [!] CLS dataset is loading only "books" domain data
     - [!] documents should be trimmed to the same length (?)
     - [!] overall gfun results logger