import os

# Pin the process to the first GPU unless the caller already selected one.
# This must run before the project imports below, which presumably load
# CUDA-aware libraries (e.g. torch) that read this variable at import time.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

from argparse import ArgumentParser
from time import perf_counter

from dataManager.utils import get_dataset
from evaluation.evaluate import evaluate, log_eval
from gfun.generalizedFunnelling import GeneralizedFunnelling

# TODO:
#   - [!] LR scheduler
#   - [!] CLS dataset is loading only "books" domain data
#   - [!] documents should be trimmed to the same length (?)
#   - [!] overall gfun results logger
#   - add Sphinx documentation
#   - [!] zero-shot setup
#   - FFNN posterior-probabilities' dependent
#   - re-init langs when loading VGFs?
#   - [!] experiment with weight init of Attention-aggregator

# Default location of the MUSE embeddings handed to the multilingual VGF.
DEFAULT_EMBED_DIR = "~/resources/muse_embeddings"


def main(args):
    """Fit (or load) a gFun model on ``args.dataset`` and evaluate it on the test split.

    The training split is passed to ``GeneralizedFunnelling.fit``. The fitted
    model is saved unless ``args.load_trained`` is set or ``args.nosave`` is
    true. Predictions on the test split are then scored with ``evaluate``
    and logged with ``log_eval``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options; see ``_build_parser`` for the fields.

    Raises
    ------
    ValueError
        If no pre-trained model is given (``args.load_trained is None``) and
        no view-generating function (VGF) flag is enabled.
    """
    dataset = get_dataset(args.dataset, args)
    lX, lY = dataset.training()
    lX_te, lY_te = dataset.test()

    tinit = perf_counter()

    # Use an explicit check rather than `assert`: asserts vanish under `python -O`.
    if args.load_trained is None:
        enabled_vgfs = (
            args.posteriors,
            args.wce,
            args.multilingual,
            args.textual_transformer,
            args.visual_transformer,
        )
        if not any(enabled_vgfs):
            raise ValueError("At least one of VGF must be True")

    gfun = GeneralizedFunnelling(
        # dataset params ----------------------
        dataset_name=args.dataset,
        langs=dataset.langs(),
        num_labels=dataset.num_labels(),
        classification_type=args.clf_type,
        # Posterior VGF params ----------------
        posterior=args.posteriors,
        # Multilingual VGF params -------------
        multilingual=args.multilingual,
        # getattr keeps callers that build their own Namespace (without
        # embed_dir) working with the previous default path.
        embed_dir=getattr(args, "embed_dir", DEFAULT_EMBED_DIR),
        # WCE VGF params ----------------------
        wce=args.wce,
        # Transformer VGF params --------------
        textual_transformer=args.textual_transformer,
        textual_transformer_name=args.transformer_name,
        batch_size=args.batch_size,
        eval_batch_size=args.eval_batch_size,
        epochs=args.epochs,
        textual_lr=args.textual_lr,
        visual_lr=args.visual_lr,
        max_length=args.max_length,
        patience=args.patience,
        evaluate_step=args.evaluate_step,
        device=args.device,
        # Visual Transformer VGF params -------
        visual_transformer=args.visual_transformer,
        visual_transformer_name=args.visual_transformer_name,
        # General params ----------------------
        # Note: `--features` uses store_false, so passing the flag sets
        # probabilistic=False.
        probabilistic=args.features,
        aggfunc=args.aggfunc,
        optimc=args.optimc,
        load_trained=args.load_trained,
        load_meta=args.meta,
        n_jobs=args.n_jobs,
    )

    gfun.fit(lX, lY)

    # Only persist freshly trained models.
    if args.load_trained is None and not args.nosave:
        gfun.save(save_first_tier=True, save_meta=True)

    timetr = perf_counter()
    print(f"- training completed in {timetr - tinit:.2f} seconds")

    gfun_preds = gfun.transform(lX_te)
    test_eval = evaluate(lY_te, gfun_preds, clf_type=args.clf_type, n_jobs=args.n_jobs)
    log_eval(test_eval, phase="test", clf_type=args.clf_type)

    timeval = perf_counter()
    print(f"- testing completed in {timeval - timetr:.2f} seconds")


def _build_parser():
    """Return the ``ArgumentParser`` describing this script's command line."""
    parser = ArgumentParser(description="Train and evaluate a gFun model.")
    parser.add_argument(
        "-l",
        "--load_trained",
        type=str,
        default=None,
        help="forwarded as load_trained; when set, the VGF check and saving are skipped",
    )
    parser.add_argument("--meta", action="store_true", help="forwarded as load_meta")
    parser.add_argument("--nosave", action="store_true", help="do not save the trained model")
    parser.add_argument("--device", type=str, default="cuda", help="device for the transformer VGFs")

    # Dataset parameters (the whole namespace is passed to get_dataset) ----
    parser.add_argument("-d", "--dataset", type=str, default="rcv1-2")
    parser.add_argument("--domains", type=str, default="all")
    parser.add_argument("--nrows", type=int, default=None)
    parser.add_argument("--min_count", type=int, default=10)
    parser.add_argument("--max_labels", type=int, default=50)
    parser.add_argument("--clf_type", type=str, default="multilabel")

    # gFUN parameters ----------------------
    parser.add_argument("-p", "--posteriors", action="store_true", help="enable the posterior VGF")
    parser.add_argument("-m", "--multilingual", action="store_true", help="enable the multilingual VGF")
    parser.add_argument("-w", "--wce", action="store_true", help="enable the WCE VGF")
    parser.add_argument("-t", "--textual_transformer", action="store_true", help="enable the textual transformer VGF")
    parser.add_argument("-v", "--visual_transformer", action="store_true", help="enable the visual transformer VGF")
    parser.add_argument("--n_jobs", type=int, default=-1)
    parser.add_argument("--optimc", action="store_true")
    parser.add_argument(
        "--features",
        action="store_false",
        help="pass to set probabilistic=False (default: probabilistic=True)",
    )
    parser.add_argument("--aggfunc", type=str, default="mean")
    parser.add_argument(
        "--embed_dir",
        type=str,
        default=DEFAULT_EMBED_DIR,
        help="directory of the embeddings used by the multilingual VGF",
    )

    # transformer parameters ---------------
    parser.add_argument("--transformer_name", type=str, default="mbert", help="textual transformer name")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--eval_batch_size", type=int, default=128)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--textual_lr", type=float, default=1e-5)
    parser.add_argument("--visual_lr", type=float, default=1e-5)
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--patience", type=int, default=5)
    parser.add_argument("--evaluate_step", type=int, default=10)

    # Visual Transformer parameters --------
    parser.add_argument("--visual_transformer_name", type=str, default="vit")
    return parser


if __name__ == "__main__":
    main(_build_parser().parse_args())