# gfun_multimodal/main.py

import pickle
from argparse import ArgumentParser
from os.path import expanduser
from time import time
from dataManager.amazonDataset import AmazonDataset
from dataManager.multilingualDataset import MultilingualDataset
from dataManager.multiNewsDataset import MultiNewsDataset
from dataManager.glamiDataset import GlamiDataset
from dataManager.gFunDataset import gFunDataset
from evaluation.evaluate import evaluate, log_eval
from gfun.generalizedFunnelling import GeneralizedFunnelling
"""
TODO:
- [!] add support for Binary Datasets (e.g. cls)
- [!] logging
- add documentations sphinx
- [!] zero-shot setup
- FFNN posterior-probabilities' dependent
- re-init langs when loading VGFs?
- [!] loss of Attention-aggregator seems to be uncorrelated with Macro-F1 on the validation set!
- [!] experiment with weight init of Attention-aggregator
"""

def get_dataset(datasetname, args):
    assert datasetname in [
        "multinews",
        "amazon",
        "rcv1-2",
        "glami",
        "cls",
    ], "dataset not supported"

    RCV_DATAPATH = expanduser(
        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
    )
    JRC_DATAPATH = expanduser(
        "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
    )
    CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
    MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")

    if datasetname == "multinews":
        # TODO: convert to gFunDataset
        raise NotImplementedError
        # unreachable, kept as reference for the port to gFunDataset
        dataset = MultiNewsDataset(
            MULTINEWS_DATAPATH,
            excluded_langs=["ar", "pe", "pl", "tr", "ua"],
        )
    elif datasetname == "amazon":
        # TODO: convert to gFunDataset
        raise NotImplementedError
        # unreachable, kept as reference for the port to gFunDataset
        dataset = AmazonDataset(
            domains=args.domains,
            nrows=args.nrows,
            min_count=args.min_count,
            max_labels=args.max_labels,
        )
    elif datasetname == "rcv1-2":
        dataset = gFunDataset(
            dataset_dir=RCV_DATAPATH,
            is_textual=True,
            is_visual=False,
            is_multilabel=True,
            nrows=args.nrows,
        )
    elif datasetname == "glami":
        dataset = gFunDataset(
            dataset_dir=GLAMI_DATAPATH,
            is_textual=True,
            is_visual=True,
            is_multilabel=False,
            nrows=args.nrows,
        )
    elif datasetname == "cls":
        dataset = gFunDataset(
            dataset_dir=CLS_DATAPATH,
            is_textual=True,
            is_visual=False,
            is_multilabel=False,
            nrows=args.nrows,
        )
    else:
        raise NotImplementedError
    return dataset
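
# Usage sketch (hypothetical, for illustration): get_dataset() only reads the
# argparse fields that the chosen dataset needs, so a bare Namespace suffices
# for "rcv1-2":
#
#   from argparse import Namespace
#   ds = get_dataset("rcv1-2", Namespace(nrows=None))
#   lX, lY = ds.training()  # presumably language-indexed docs and labels
#   lX_te, lY_te = ds.test()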

def main(args):
    dataset = get_dataset(args.dataset, args)
    if isinstance(
        dataset, (MultilingualDataset, MultiNewsDataset, GlamiDataset, gFunDataset)
    ):
        lX, lY = dataset.training()
        lX_te, lY_te = dataset.test()
    else:
        # legacy datasets expose their data directly; note that the test
        # split (lX_te, lY_te) is not defined on this path
        lX = dataset.dX
        lY = dataset.dY

    tinit = time()

    if args.load_trained is None:
        assert any(
            [
                args.posteriors,
                args.wce,
                args.multilingual,
                args.textual_transformer,
                args.visual_transformer,
            ]
        ), "At least one VGF must be enabled"

    gfun = GeneralizedFunnelling(
        # dataset params ----------------------
        dataset_name=args.dataset,
        langs=dataset.langs(),
        num_labels=dataset.num_labels(),
        # Posterior VGF params ----------------
        posterior=args.posteriors,
        # Multilingual VGF params -------------
        multilingual=args.multilingual,
        embed_dir="~/resources/muse_embeddings",
        # WCE VGF params ----------------------
        wce=args.wce,
        # Transformer VGF params --------------
        textual_transformer=args.textual_transformer,
        textual_transformer_name=args.transformer_name,
        batch_size=args.batch_size,
        epochs=args.epochs,
        lr=args.lr,
        max_length=args.max_length,
        patience=args.patience,
        evaluate_step=args.evaluate_step,
        device="cuda",
        # Visual Transformer VGF params -------
        visual_transformer=args.visual_transformer,
        visual_transformer_name=args.visual_transformer_name,
        # the following are shared with the textual transformer VGF:
        # batch_size, epochs, lr, patience, evaluate_step, device
        # General params ----------------------
        probabilistic=args.features,
        aggfunc=args.aggfunc,
        optimc=args.optimc,
        load_trained=args.load_trained,
        load_meta=args.meta,
        n_jobs=args.n_jobs,
    )
    # gfun.get_config()

    gfun.fit(lX, lY)

    if args.load_trained is None and not args.nosave:
        gfun.save(save_first_tier=True, save_meta=True)

    # print("- Computing evaluation on training set")
    # preds = gfun.transform(lX)
    # train_eval = evaluate(lY, preds)
    # log_eval(train_eval, phase="train")

    timetr = time()
    print(f"- training completed in {timetr - tinit:.2f} seconds")

    gfun_preds = gfun.transform(lX_te)
    test_eval = evaluate(lY_te, gfun_preds)
    log_eval(test_eval, phase="test")

    timeval = time()
    print(f"- testing completed in {timeval - timetr:.2f} seconds")

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("-l", "--load_trained", type=str, default=None)
    parser.add_argument("--meta", action="store_true")
    parser.add_argument("--nosave", action="store_true")
    # Dataset parameters -------------------
    parser.add_argument("-d", "--dataset", type=str, default="rcv1-2")
    parser.add_argument("--domains", type=str, default="all")
    parser.add_argument("--nrows", type=int, default=None)
    parser.add_argument("--min_count", type=int, default=10)
    parser.add_argument("--max_labels", type=int, default=50)
    # gFUN parameters ----------------------
    parser.add_argument("-p", "--posteriors", action="store_true")
    parser.add_argument("-m", "--multilingual", action="store_true")
    parser.add_argument("-w", "--wce", action="store_true")
    parser.add_argument("-t", "--textual_transformer", action="store_true")
    parser.add_argument("-v", "--visual_transformer", action="store_true")
    parser.add_argument("--n_jobs", type=int, default=-1)
    parser.add_argument("--optimc", action="store_true")
    # store_false: probabilistic features are on by default; pass --features to disable
    parser.add_argument("--features", action="store_false")
    parser.add_argument("--aggfunc", type=str, default="mean")
    # transformer parameters ---------------
    parser.add_argument("--transformer_name", type=str, default="mbert")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--patience", type=int, default=5)
    parser.add_argument("--evaluate_step", type=int, default=10)
    # Visual Transformer parameters --------
    parser.add_argument("--visual_transformer_name", type=str, default="vit")

    args = parser.parse_args()
    main(args)