fixed TransformerGen init

Andrea Pedrotti 2023-03-16 12:12:39 +01:00
parent b34da419d0
commit ee38bcda10
4 changed files with 33 additions and 23 deletions


@@ -59,8 +59,8 @@ class GeneralizedFunnelling:
         # Textual Transformer VGF params ----------
         self.textual_trf_name = textual_transformer_name
         self.epochs = epochs
-        self.txt_trf_lr = textual_lr
-        self.vis_trf_lr = visual_lr
+        self.textual_trf_lr = textual_lr
+        self.textual_scheduler = "ReduceLROnPlateau"
         self.batch_size_trf = batch_size
         self.eval_batch_size_trf = eval_batch_size
         self.max_length = max_length
@@ -70,6 +70,8 @@ class GeneralizedFunnelling:
         self.device = device
         # Visual Transformer VGF params ----------
         self.visual_trf_name = visual_transformer_name
+        self.visual_trf_lr = visual_lr
+        self.visual_scheduler = "ReduceLROnPlateau"
         # Metaclassifier params ------------
         self.optimc = optimc
         # -------------------
@@ -115,7 +117,7 @@ class GeneralizedFunnelling:
             self.attn_aggregator = AttentionAggregator(
                 embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
                 out_dim=self.num_labels,
-                lr=self.txt_trf_lr,
+                lr=self.textual_trf_lr,
                 patience=self.patience,
                 num_heads=1,
                 device=self.device,
@@ -149,7 +151,8 @@ class GeneralizedFunnelling:
             transformer_vgf = TextualTransformerGen(
                 dataset_name=self.dataset_name,
                 model_name=self.textual_trf_name,
-                lr=self.txt_trf_lr,
+                lr=self.textual_trf_lr,
+                scheduler=self.textual_scheduler,
                 epochs=self.epochs,
                 batch_size=self.batch_size_trf,
                 batch_size_eval=self.eval_batch_size_trf,
@@ -168,7 +171,8 @@ class GeneralizedFunnelling:
             visual_trasformer_vgf = VisualTransformerGen(
                 dataset_name=self.dataset_name,
                 model_name="vit",
-                lr=self.vis_trf_lr,
+                lr=self.visual_trf_lr,
+                scheduler=self.visual_scheduler,
                 epochs=self.epochs,
                 batch_size=self.batch_size_trf,
                 batch_size_eval=self.eval_batch_size_trf,
@@ -185,7 +189,7 @@ class GeneralizedFunnelling:
             self.attn_aggregator = AttentionAggregator(
                 embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
                 out_dim=self.num_labels,
-                lr=self.txt_trf_lr,
+                lr=self.textual_trf_lr,
                 patience=self.patience,
                 num_heads=1,
                 device=self.device,
@@ -346,7 +350,7 @@ class GeneralizedFunnelling:
             pickle.dump(self.metaclassifier, f)
         return
 
-    def save_first_tier_learners(self):
+    def save_first_tier_learners(self, model_id):
         for vgf in self.first_tier_learners:
             vgf.save_vgf(model_id=self._model_id)
         return self
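
Taken together, each modality now carries its own learning rate and scheduler name and forwards both to its view-generating function. A minimal sketch of the resulting wiring, with the classes reduced to the fields this commit touches (everything outside those fields is an assumption):

# Sketch: per-modality lr/scheduler wiring after this commit.
# Only the shown fields come from the diff; the rest is illustrative.
class TextualTransformerGen:
    def __init__(self, lr, scheduler):
        self.lr, self.scheduler = lr, scheduler

class VisualTransformerGen:
    def __init__(self, lr, scheduler):
        self.lr, self.scheduler = lr, scheduler

class GeneralizedFunnelling:
    def __init__(self, textual_lr=1e-5, visual_lr=1e-5):
        self.textual_trf_lr = textual_lr              # renamed from self.txt_trf_lr
        self.textual_scheduler = "ReduceLROnPlateau"  # hardcoded, not a constructor arg
        self.visual_trf_lr = visual_lr                # renamed from self.vis_trf_lr
        self.visual_scheduler = "ReduceLROnPlateau"

    def build_vgfs(self):
        # build_vgfs is a hypothetical helper name; the diff does this in init logic.
        textual = TextualTransformerGen(lr=self.textual_trf_lr,
                                        scheduler=self.textual_scheduler)
        visual = VisualTransformerGen(lr=self.visual_trf_lr,
                                      scheduler=self.visual_scheduler)
        return textual, visual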


@@ -70,22 +70,24 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         verbose=False,
         patience=5,
         classification_type="multilabel",
+        scheduler="ReduceLROnPlateau",
     ):
         super().__init__(
             self._validate_model_name(model_name),
             dataset_name,
-            epochs,
-            lr,
-            batch_size,
-            batch_size_eval,
-            max_length,
-            print_steps,
-            device,
-            probabilistic,
-            n_jobs,
-            evaluate_step,
-            verbose,
-            patience,
+            epochs=epochs,
+            lr=lr,
+            scheduler=scheduler,
+            batch_size=batch_size,
+            batch_size_eval=batch_size_eval,
+            device=device,
+            evaluate_step=evaluate_step,
+            patience=patience,
+            probabilistic=probabilistic,
+            max_length=max_length,
+            print_steps=print_steps,
+            n_jobs=n_jobs,
+            verbose=verbose,
         )
         self.clf_type = classification_type
         self.fitted = False
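
The old call forwarded everything to the base TransformerGen positionally; the commit switches to keyword arguments, which keeps the call correct even if the base signature grows or is reordered. An illustrative, self-contained sketch of the hazard that keyword forwarding avoids (the drift scenario is hypothetical, not this repo's exact history):

# Hypothetical base class whose signature gains a parameter in the middle.
class Base:
    def __init__(self, model_name, dataset_name, epochs=10, lr=1e-5,
                 scheduler=None, batch_size=32):
        self.epochs, self.lr = epochs, lr
        self.scheduler, self.batch_size = scheduler, batch_size

# Positional forwarding silently misbinds once `scheduler` is inserted:
broken = Base("bert", "demo", 10, 1e-5, 32)
assert broken.scheduler == 32  # 32 was meant to be batch_size

# Keyword forwarding (this commit's fix) binds by name, order-independent:
fixed = Base("bert", "demo", epochs=10, lr=1e-5, batch_size=32)
assert fixed.scheduler is None and fixed.batch_size == 32
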
@@ -187,8 +189,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             vgf_name="textual_trf",
             classification_type=self.clf_type,
             n_jobs=self.n_jobs,
-            scheduler_name="ReduceLROnPlateau",
-            # scheduler_name=None,
+            scheduler_name=self.scheduler,
         )
         trainer.train(
             train_dataloader=tra_dataloader,
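
With scheduler_name now read from self.scheduler, the scheduler choice is configurable per generator instead of being fixed at the call site. Presumably the trainer maps the string onto torch.optim.lr_scheduler.ReduceLROnPlateau; a minimal sketch of that common pattern (the dispatch function below is assumed, only the string "ReduceLROnPlateau" appears in this diff):

import torch

def build_scheduler(name, optimizer, patience=5):
    # Hypothetical name-to-scheduler dispatch, shown for illustration.
    if name == "ReduceLROnPlateau":
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", patience=patience
        )
    if name is None:
        return None
    raise ValueError(f"unknown scheduler: {name}")

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
scheduler = build_scheduler("ReduceLROnPlateau", optimizer)
scheduler.step(0.42)  # ReduceLROnPlateau steps on a monitored metric (e.g. val loss)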


@@ -26,6 +26,7 @@ class TransformerGen:
         evaluate_step=10,
         verbose=False,
         patience=5,
+        scheduler=None,
     ):
         self.model_name = model_name
         self.dataset_name = dataset_name
@@ -46,6 +47,7 @@ class TransformerGen:
         self.verbose = verbose
         self.patience = patience
         self.datasets = {}
+        self.scheduler = scheduler
         self.feature2posterior_projector = (
             self.make_probabilistic() if probabilistic else None
         )
@@ -101,6 +103,7 @@ class TransformerGen:
             "dataset_name": self.dataset_name,
             "epochs": self.epochs,
             "lr": self.lr,
+            "scheduler": self.scheduler,
             "batch_size": self.batch_size,
             "batch_size_eval": self.batch_size_eval,
             "max_length": self.max_length,
@@ -111,4 +114,4 @@ class TransformerGen:
             "evaluate_step": self.evaluate_step,
             "verbose": self.verbose,
             "patience": self.patience,
-        }
\ No newline at end of file
+        }
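
Because the scheduler is now a key in this dict, it travels with whatever the generator persists or logs alongside its other hyperparameters. A reduced sketch (the get_config method name and surrounding class shape are assumptions based on the keys shown):

# Reduced sketch of the config dict after this commit; only the
# "scheduler" key is new, the method and class shape are assumed.
class TransformerGen:
    def __init__(self, model_name, dataset_name, lr=1e-5, scheduler=None):
        self.model_name, self.dataset_name = model_name, dataset_name
        self.lr, self.scheduler = lr, scheduler

    def get_config(self):
        return {
            "model_name": self.model_name,
            "dataset_name": self.dataset_name,
            "lr": self.lr,
            "scheduler": self.scheduler,  # new in this commit
        }

print(TransformerGen("vit", "demo", scheduler="ReduceLROnPlateau").get_config())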


@@ -20,6 +20,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
         model_name,
         dataset_name,
         lr=1e-5,
+        scheduler="ReduceLROnPlateau",
         epochs=10,
         batch_size=32,
         batch_size_eval=128,
@@ -32,8 +33,9 @@ class VisualTransformerGen(ViewGen, TransformerGen):
         super().__init__(
             model_name,
             dataset_name,
+            lr=lr,
             epochs=epochs,
-            lr=lr,
+            scheduler=scheduler,
             batch_size=batch_size,
             batch_size_eval=batch_size_eval,
             device=device,