diff --git a/TweetSentQuant/experiments.py b/TweetSentQuant/experiments.py index 5c0ceca..93ca993 100644 --- a/TweetSentQuant/experiments.py +++ b/TweetSentQuant/experiments.py @@ -41,7 +41,7 @@ def quantification_models(): device = 'cuda' if torch.cuda.is_available() else 'cpu' print(f'Running QuaNet in {device}') - yield 'quanet', QuaNet(PCALR(**newLR().get_params()), settings.SAMPLE_SIZE, checkpointdir=args.checkpointdir, device=device), lr_params + yield 'quanet', QuaNet(PCALR(**newLR().get_params()), settings.SAMPLE_SIZE, tr_iter_per_poch=500, va_iter_per_poch=100, checkpointdir=args.checkpointdir, device=device), lr_params param_mod_sel={'sample_size':settings.SAMPLE_SIZE, 'n_prevpoints':21, 'n_repetitions':5} #yield 'epaccmaeptr', EPACC(newLR(), param_grid=lr_params, optim='mae', policy='ptr', param_mod_sel=param_mod_sel, n_jobs=settings.ENSEMBLE_N_JOBS), None