updated argparse
This commit is contained in:
parent
5e41b4517a
commit
fece8d059e
10
main.py
10
main.py
|
@ -58,7 +58,7 @@ def main(args):
|
|||
wce=args.wce,
|
||||
# Transformer VGF params --------------
|
||||
textual_transformer=args.textual_transformer,
|
||||
textual_transformer_name=args.transformer_name,
|
||||
textual_transformer_name=args.textual_trf_name,
|
||||
batch_size=args.batch_size,
|
||||
eval_batch_size=args.eval_batch_size,
|
||||
epochs=args.epochs,
|
||||
|
@ -70,14 +70,14 @@ def main(args):
|
|||
device=args.device,
|
||||
# Visual Transformer VGF params --------------
|
||||
visual_transformer=args.visual_transformer,
|
||||
visual_transformer_name=args.visual_transformer_name,
|
||||
visual_transformer_name=args.visual_trf_name,
|
||||
# batch_size=args.batch_size,
|
||||
# epochs=args.epochs,
|
||||
# lr=args.lr,
|
||||
# patience=args.patience,
|
||||
# evaluate_step=args.evaluate_step,
|
||||
# device="cuda",
|
||||
# General params ----------------------
|
||||
# General params ---------------------
|
||||
probabilistic=args.features,
|
||||
aggfunc=args.aggfunc,
|
||||
optimc=args.optimc,
|
||||
|
@ -133,7 +133,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--features", action="store_false")
|
||||
parser.add_argument("--aggfunc", type=str, default="mean")
|
||||
# transformer parameters ---------------
|
||||
parser.add_argument("--transformer_name", type=str, default="mbert")
|
||||
parser.add_argument("--textual_trf_name", type=str, default="mbert")
|
||||
parser.add_argument("--batch_size", type=int, default=32)
|
||||
parser.add_argument("--eval_batch_size", type=int, default=128)
|
||||
parser.add_argument("--epochs", type=int, default=100)
|
||||
|
@ -143,7 +143,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--patience", type=int, default=5)
|
||||
parser.add_argument("--evaluate_step", type=int, default=10)
|
||||
# Visual Transformer parameters --------------
|
||||
parser.add_argument("--visual_transformer_name", type=str, default="vit")
|
||||
parser.add_argument("--visual_trf_name", type=str, default="vit")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
Loading…
Reference in New Issue