diff --git a/gfun/generalizedFunnelling.py b/gfun/generalizedFunnelling.py
index 71c9054..14efc6f 100644
--- a/gfun/generalizedFunnelling.py
+++ b/gfun/generalizedFunnelling.py
@@ -59,8 +59,8 @@ class GeneralizedFunnelling:
         # Textual Transformer VGF params ----------
         self.textual_trf_name = textual_transformer_name
         self.epochs = epochs
-        self.txt_trf_lr = textual_lr
-        self.vis_trf_lr = visual_lr
+        self.textual_trf_lr = textual_lr
+        self.textual_scheduler = "ReduceLROnPlateau"
         self.batch_size_trf = batch_size
         self.eval_batch_size_trf = eval_batch_size
         self.max_length = max_length
@@ -70,6 +70,8 @@ class GeneralizedFunnelling:
         self.device = device
         # Visual Transformer VGF params ----------
         self.visual_trf_name = visual_transformer_name
+        self.visual_trf_lr = visual_lr
+        self.visual_scheduler = "ReduceLROnPlateau"
         # Metaclassifier params ------------
         self.optimc = optimc
         # -------------------
@@ -115,7 +117,7 @@ class GeneralizedFunnelling:
             self.attn_aggregator = AttentionAggregator(
                 embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
                 out_dim=self.num_labels,
-                lr=self.txt_trf_lr,
+                lr=self.textual_trf_lr,
                 patience=self.patience,
                 num_heads=1,
                 device=self.device,
@@ -149,7 +151,8 @@ class GeneralizedFunnelling:
             transformer_vgf = TextualTransformerGen(
                 dataset_name=self.dataset_name,
                 model_name=self.textual_trf_name,
-                lr=self.txt_trf_lr,
+                lr=self.textual_trf_lr,
+                scheduler=self.textual_scheduler,
                 epochs=self.epochs,
                 batch_size=self.batch_size_trf,
                 batch_size_eval=self.eval_batch_size_trf,
@@ -168,7 +171,8 @@ class GeneralizedFunnelling:
             visual_trasformer_vgf = VisualTransformerGen(
                 dataset_name=self.dataset_name,
                 model_name="vit",
-                lr=self.vis_trf_lr,
+                lr=self.visual_trf_lr,
+                scheduler=self.visual_scheduler,
                 epochs=self.epochs,
                 batch_size=self.batch_size_trf,
                 batch_size_eval=self.eval_batch_size_trf,
@@ -185,7 +189,7 @@ class GeneralizedFunnelling:
             self.attn_aggregator = AttentionAggregator(
                 embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
                 out_dim=self.num_labels,
-                lr=self.txt_trf_lr,
+                lr=self.textual_trf_lr,
                 patience=self.patience,
                 num_heads=1,
                 device=self.device,
@@ -346,7 +350,7 @@ class GeneralizedFunnelling:
             pickle.dump(self.metaclassifier, f)
         return

-    def save_first_tier_learners(self):
+    def save_first_tier_learners(self, model_id):
         for vgf in self.first_tier_learners:
-            vgf.save_vgf(model_id=self._model_id)
+            vgf.save_vgf(model_id=model_id)
         return self
diff --git a/gfun/vgfs/textualTransformerGen.py b/gfun/vgfs/textualTransformerGen.py
index 16b70ed..68a358e 100644
--- a/gfun/vgfs/textualTransformerGen.py
+++ b/gfun/vgfs/textualTransformerGen.py
@@ -70,22 +70,24 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         verbose=False,
         patience=5,
         classification_type="multilabel",
+        scheduler="ReduceLROnPlateau",
     ):
         super().__init__(
             self._validate_model_name(model_name),
             dataset_name,
-            epochs,
-            lr,
-            batch_size,
-            batch_size_eval,
-            max_length,
-            print_steps,
-            device,
-            probabilistic,
-            n_jobs,
-            evaluate_step,
-            verbose,
-            patience,
+            epochs=epochs,
+            lr=lr,
+            scheduler=scheduler,
+            batch_size=batch_size,
+            batch_size_eval=batch_size_eval,
+            device=device,
+            evaluate_step=evaluate_step,
+            patience=patience,
+            probabilistic=probabilistic,
+            max_length=max_length,
+            print_steps=print_steps,
+            n_jobs=n_jobs,
+            verbose=verbose,
         )
         self.clf_type = classification_type
         self.fitted = False
@@ -187,8 +189,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             vgf_name="textual_trf",
             classification_type=self.clf_type,
             n_jobs=self.n_jobs,
-            scheduler_name="ReduceLROnPlateau",
-            # scheduler_name=None,
+            scheduler_name=self.scheduler,
         )
         trainer.train(
             train_dataloader=tra_dataloader,
diff --git a/gfun/vgfs/transformerGen.py b/gfun/vgfs/transformerGen.py
index d3d09a3..77e4bbc 100644
--- a/gfun/vgfs/transformerGen.py
+++ b/gfun/vgfs/transformerGen.py
@@ -26,6 +26,7 @@ class TransformerGen:
         evaluate_step=10,
         verbose=False,
         patience=5,
+        scheduler=None,
     ):
         self.model_name = model_name
         self.dataset_name = dataset_name
@@ -46,6 +47,7 @@ class TransformerGen:
         self.verbose = verbose
         self.patience = patience
         self.datasets = {}
+        self.scheduler = scheduler
         self.feature2posterior_projector = (
             self.make_probabilistic() if probabilistic else None
         )
@@ -101,6 +103,7 @@
             "dataset_name": self.dataset_name,
             "epochs": self.epochs,
             "lr": self.lr,
+            "scheduler": self.scheduler,
             "batch_size": self.batch_size,
             "batch_size_eval": self.batch_size_eval,
             "max_length": self.max_length,
@@ -111,4 +114,4 @@
             "evaluate_step": self.evaluate_step,
             "verbose": self.verbose,
             "patience": self.patience,
-        }
\ No newline at end of file
+        }
diff --git a/gfun/vgfs/visualTransformerGen.py b/gfun/vgfs/visualTransformerGen.py
index c105003..eb6fdf7 100644
--- a/gfun/vgfs/visualTransformerGen.py
+++ b/gfun/vgfs/visualTransformerGen.py
@@ -20,6 +20,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
         model_name,
         dataset_name,
         lr=1e-5,
+        scheduler="ReduceLROnPlateau",
         epochs=10,
         batch_size=32,
         batch_size_eval=128,
@@ -32,8 +33,9 @@
         super().__init__(
             model_name,
             dataset_name,
-            lr=lr,
             epochs=epochs,
+            lr=lr,
+            scheduler=scheduler,
             batch_size=batch_size,
             batch_size_eval=batch_size_eval,
             device=device,
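
Note on the new plumbing: a `scheduler` string now flows from `GeneralizedFunnelling` through `TransformerGen` down to the trainer as `scheduler_name`, replacing the value previously hard-coded (with a commented-out `None` alternative) in textualTransformerGen.py. The sketch below illustrates how such a name is typically resolved to a `torch.optim.lr_scheduler` object; it is a minimal assumption-laden example, not this repository's actual Trainer, and `build_scheduler` is a hypothetical helper.

    import torch

    def build_scheduler(optimizer, scheduler_name, patience=5):
        # Hypothetical helper (not the repo's Trainer): map the scheduler_name
        # string threaded through the VGFs onto a concrete torch scheduler.
        if scheduler_name is None:
            # Matches the TransformerGen base-class default: constant learning rate.
            return None
        if scheduler_name == "ReduceLROnPlateau":
            # Halve the LR once the monitored metric has stalled for
            # `patience` consecutive evaluations.
            return torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode="min", factor=0.5, patience=patience
            )
        raise ValueError(f"unsupported scheduler: {scheduler_name}")

    # ReduceLROnPlateau is stepped on a validation metric, not per batch:
    #     if scheduler is not None:
    #         scheduler.step(val_loss)

Since the base class defaults to `scheduler=None` while both subclasses default to `"ReduceLROnPlateau"`, only code instantiating `TransformerGen` directly falls back to a constant learning rate.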