from argparse import ArgumentParser
from os.path import expanduser
from time import time

from dataManager.amazonDataset import AmazonDataset
from dataManager.multilingualDataset import MultilingualDataset
from dataManager.multiNewsDataset import MultiNewsDataset
from dataManager.glamiDataset import GlamiDataset
from dataManager.gFunDataset import gFunDataset
from evaluation.evaluate import evaluate, log_eval
from gfun.generalizedFunnelling import GeneralizedFunnelling

"""
TODO:
    - [!] add support for binary datasets (e.g. cls) - NB: CLS dataset is loading only "books" domain data
    - [!] documents should be trimmed to the same length (?)
    - [!] logging
    - add documentation (Sphinx)
    - [!] zero-shot setup
    - FFNN posterior-probabilities' dependent
    - re-init langs when loading VGFs?
    - [!] loss of Attention-aggregator seems to be uncorrelated with Macro-F1 on the validation set!
    - [!] experiment with weight init of Attention-aggregator
"""


def get_dataset(datasetname, args):
    """Load and return the dataset identified by `datasetname`."""
    assert datasetname in [
        "multinews",
        "amazon",
        "rcv1-2",
        "glami",
        "cls",
    ], "dataset not supported"

    RCV_DATAPATH = expanduser(
        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
    )
    JRC_DATAPATH = expanduser(
        "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
    )
    CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
    MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")

    if datasetname == "multinews":
        # TODO: convert to gFunDataset
        raise NotImplementedError
        dataset = MultiNewsDataset(
            MULTINEWS_DATAPATH,
            excluded_langs=["ar", "pe", "pl", "tr", "ua"],
        )
    elif datasetname == "amazon":
        # TODO: convert to gFunDataset
        raise NotImplementedError
        dataset = AmazonDataset(
            domains=args.domains,
            nrows=args.nrows,
            min_count=args.min_count,
            max_labels=args.max_labels,
        )
    elif datasetname == "rcv1-2":
        dataset = gFunDataset(
            dataset_dir=RCV_DATAPATH,
            is_textual=True,
            is_visual=False,
            is_multilabel=True,
            nrows=args.nrows,
        )
    elif datasetname == "glami":
        dataset = gFunDataset(
            dataset_dir=GLAMI_DATAPATH,
            is_textual=True,
            is_visual=True,
            is_multilabel=False,
            nrows=args.nrows,
        )
    elif datasetname == "cls":
        dataset = gFunDataset(
            dataset_dir=CLS_DATAPATH,
            is_textual=True,
            is_visual=False,
            is_multilabel=False,
            nrows=args.nrows,
        )
    else:
        raise NotImplementedError
    return dataset


def main(args):
    dataset = get_dataset(args.dataset, args)
    if isinstance(
        dataset, (MultilingualDataset, MultiNewsDataset, GlamiDataset, gFunDataset)
    ):
        lX, lY = dataset.training()
        lX_te, lY_te = dataset.test()
    else:
        lX = dataset.dX
        lY = dataset.dY

    tinit = time()

    if args.load_trained is None:
        assert any(
            [
                args.posteriors,
                args.wce,
                args.multilingual,
                args.textual_transformer,
                args.visual_transformer,
            ]
        ), "At least one VGF must be enabled"

    gfun = GeneralizedFunnelling(
        # dataset params ----------------------
        dataset_name=args.dataset,
        langs=dataset.langs(),
        num_labels=dataset.num_labels(),
        # Posterior VGF params ----------------
        posterior=args.posteriors,
        # Multilingual VGF params -------------
        multilingual=args.multilingual,
        embed_dir="~/resources/muse_embeddings",
        # WCE VGF params ----------------------
        wce=args.wce,
        # Textual Transformer VGF params ------
        textual_transformer=args.textual_transformer,
        textual_transformer_name=args.transformer_name,
        batch_size=args.batch_size,
        epochs=args.epochs,
        lr=args.lr,
        max_length=args.max_length,
        patience=args.patience,
        evaluate_step=args.evaluate_step,
        device="cuda",
        # Visual Transformer VGF params -------
        visual_transformer=args.visual_transformer,
        visual_transformer_name=args.visual_transformer_name,
        # batch_size=args.batch_size,
        # epochs=args.epochs,
        # lr=args.lr,
        # patience=args.patience,
        # evaluate_step=args.evaluate_step,
        # device="cuda",
        # General params ----------------------
        probabilistic=args.features,
        aggfunc=args.aggfunc,
        optimc=args.optimc,
        load_trained=args.load_trained,
        load_meta=args.meta,
        n_jobs=args.n_jobs,
    )

    # gfun.get_config()
    gfun.fit(lX, lY)

    if args.load_trained is None and not args.nosave:
        gfun.save(save_first_tier=True, save_meta=True)

    # print("- Computing evaluation on training set")
    # preds = gfun.transform(lX)
    # train_eval = evaluate(lY, preds)
    # log_eval(train_eval, phase="train")

    timetr = time()
    print(f"- training completed in {timetr - tinit:.2f} seconds")

    gfun_preds = gfun.transform(lX_te)
    test_eval = evaluate(lY_te, gfun_preds)
    log_eval(test_eval, phase="test")

    timeval = time()
    print(f"- testing completed in {timeval - timetr:.2f} seconds")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("-l", "--load_trained", type=str, default=None)
    parser.add_argument("--meta", action="store_true")
    parser.add_argument("--nosave", action="store_true")
    # Dataset parameters -------------------
    parser.add_argument("-d", "--dataset", type=str, default="rcv1-2")
    parser.add_argument("--domains", type=str, default="all")
    parser.add_argument("--nrows", type=int, default=None)
    parser.add_argument("--min_count", type=int, default=10)
    parser.add_argument("--max_labels", type=int, default=50)
    # gFUN parameters ----------------------
    parser.add_argument("-p", "--posteriors", action="store_true")
    parser.add_argument("-m", "--multilingual", action="store_true")
    parser.add_argument("-w", "--wce", action="store_true")
    parser.add_argument("-t", "--textual_transformer", action="store_true")
    parser.add_argument("-v", "--visual_transformer", action="store_true")
    parser.add_argument("--n_jobs", type=int, default=-1)
    parser.add_argument("--optimc", action="store_true")
    parser.add_argument("--features", action="store_false")
    parser.add_argument("--aggfunc", type=str, default="mean")
    # Textual Transformer parameters -------
    parser.add_argument("--transformer_name", type=str, default="mbert")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--patience", type=int, default=5)
    parser.add_argument("--evaluate_step", type=int, default=10)
    # Visual Transformer parameters --------
    parser.add_argument("--visual_transformer_name", type=str, default="vit")

    args = parser.parse_args()
    main(args)
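
# Example invocations (a sketch, not taken from the repo: the script filename
# "main.py" and the "<saved-model-id>" placeholder are assumptions; the flags
# are the ones defined by the ArgumentParser above, and the dataset paths are
# hard-coded in get_dataset):
#
#   # train gFun on rcv1-2 with the posterior, WCE, and textual-transformer
#   # (mBERT) VGFs, averaging the VGF outputs, then evaluate on the test set:
#   python main.py -d rcv1-2 -p -w -t --aggfunc mean
#
#   # reload a previously saved model together with its meta-classifier
#   # and re-run the test evaluation only:
#   python main.py -d rcv1-2 -l <saved-model-id> --meta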