fixed TransformerGen init

Andrea Pedrotti 2023-03-16 12:12:39 +01:00
parent b34da419d0
commit ee38bcda10
4 changed files with 33 additions and 23 deletions

View File

@@ -59,8 +59,8 @@ class GeneralizedFunnelling:
         # Textual Transformer VGF params ----------
         self.textual_trf_name = textual_transformer_name
         self.epochs = epochs
-        self.txt_trf_lr = textual_lr
-        self.vis_trf_lr = visual_lr
+        self.textual_trf_lr = textual_lr
+        self.textual_scheduler = "ReduceLROnPlateau"
         self.batch_size_trf = batch_size
         self.eval_batch_size_trf = eval_batch_size
         self.max_length = max_length
@@ -70,6 +70,8 @@ class GeneralizedFunnelling:
         self.device = device
         # Visual Transformer VGF params ----------
         self.visual_trf_name = visual_transformer_name
+        self.visual_trf_lr = visual_lr
+        self.visual_scheduler = "ReduceLROnPlateau"
         # Metaclassifier params ------------
         self.optimc = optimc
         # -------------------
@@ -115,7 +117,7 @@ class GeneralizedFunnelling:
             self.attn_aggregator = AttentionAggregator(
                 embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
                 out_dim=self.num_labels,
-                lr=self.txt_trf_lr,
+                lr=self.textual_trf_lr,
                 patience=self.patience,
                 num_heads=1,
                 device=self.device,
@@ -149,7 +151,8 @@ class GeneralizedFunnelling:
             transformer_vgf = TextualTransformerGen(
                 dataset_name=self.dataset_name,
                 model_name=self.textual_trf_name,
-                lr=self.txt_trf_lr,
+                lr=self.textual_trf_lr,
+                scheduler=self.textual_scheduler,
                 epochs=self.epochs,
                 batch_size=self.batch_size_trf,
                 batch_size_eval=self.eval_batch_size_trf,
@@ -168,7 +171,8 @@ class GeneralizedFunnelling:
             visual_trasformer_vgf = VisualTransformerGen(
                 dataset_name=self.dataset_name,
                 model_name="vit",
-                lr=self.vis_trf_lr,
+                lr=self.visual_trf_lr,
+                scheduler=self.visual_scheduler,
                 epochs=self.epochs,
                 batch_size=self.batch_size_trf,
                 batch_size_eval=self.eval_batch_size_trf,
@@ -185,7 +189,7 @@ class GeneralizedFunnelling:
             self.attn_aggregator = AttentionAggregator(
                 embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
                 out_dim=self.num_labels,
-                lr=self.txt_trf_lr,
+                lr=self.textual_trf_lr,
                 patience=self.patience,
                 num_heads=1,
                 device=self.device,
@@ -346,7 +350,7 @@ class GeneralizedFunnelling:
             pickle.dump(self.metaclassifier, f)
         return

-    def save_first_tier_learners(self):
+    def save_first_tier_learners(self, model_id):
        for vgf in self.first_tier_learners:
            vgf.save_vgf(model_id=self._model_id)
        return self
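The net effect in this file: each transformer VGF now reads its learning rate and scheduler from a clearly named per-modality pair (textual_trf_lr / textual_scheduler, visual_trf_lr / visual_scheduler) declared in the matching params section, rather than the ambiguous txt_trf_lr / vis_trf_lr. A minimal sketch of the resulting pattern, with hypothetical stub classes standing in for the real view generators:

class ViewGenStub:
    """Stands in for TextualTransformerGen / VisualTransformerGen."""
    def __init__(self, lr, scheduler):
        self.lr = lr
        self.scheduler = scheduler

class FunnellingStub:
    def __init__(self, textual_lr=1e-5, visual_lr=1e-4):
        # One named lr/scheduler pair per modality, as in the diff:
        self.textual_trf_lr = textual_lr
        self.textual_scheduler = "ReduceLROnPlateau"
        self.visual_trf_lr = visual_lr
        self.visual_scheduler = "ReduceLROnPlateau"

    def build(self):
        # Each view generator receives its own modality's settings.
        return (
            ViewGenStub(self.textual_trf_lr, self.textual_scheduler),
            ViewGenStub(self.visual_trf_lr, self.visual_scheduler),
        )

textual, visual = FunnellingStub().build()
assert textual.scheduler == visual.scheduler == "ReduceLROnPlateau"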

View File

@@ -70,22 +70,24 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         verbose=False,
         patience=5,
         classification_type="multilabel",
+        scheduler="ReduceLROnPlateau",
     ):
         super().__init__(
             self._validate_model_name(model_name),
             dataset_name,
-            epochs,
-            lr,
-            batch_size,
-            batch_size_eval,
-            max_length,
-            print_steps,
-            device,
-            probabilistic,
-            n_jobs,
-            evaluate_step,
-            verbose,
-            patience,
+            epochs=epochs,
+            lr=lr,
+            scheduler=scheduler,
+            batch_size=batch_size,
+            batch_size_eval=batch_size_eval,
+            device=device,
+            evaluate_step=evaluate_step,
+            patience=patience,
+            probabilistic=probabilistic,
+            max_length=max_length,
+            print_steps=print_steps,
+            n_jobs=n_jobs,
+            verbose=verbose,
         )
         self.clf_type = classification_type
         self.fitted = False
@@ -187,8 +189,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             vgf_name="textual_trf",
             classification_type=self.clf_type,
             n_jobs=self.n_jobs,
-            scheduler_name="ReduceLROnPlateau",
-            # scheduler_name=None,
+            scheduler_name=self.scheduler,
        )
        trainer.train(
            train_dataloader=tra_dataloader,
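The super().__init__ rewrite above is the substance of the commit title: the old call forwarded a dozen arguments positionally, so inserting or reordering a parameter in TransformerGen.__init__ silently shifted every later argument into the wrong slot. A toy reproduction of that failure mode (hypothetical classes, not the repo's API):

class Parent:
    # Imagine `scheduler` was inserted into this signature, shifting
    # everything after `lr` by one position.
    def __init__(self, model_name, dataset_name, epochs=10, lr=1e-5,
                 scheduler=None, batch_size=32):
        self.epochs, self.lr = epochs, lr
        self.scheduler, self.batch_size = scheduler, batch_size

class PositionalChild(Parent):
    def __init__(self):
        # Written against the *old* parent signature: batch_size=32 now
        # lands in `scheduler` -- no error, just silently wrong state.
        super().__init__("bert", "rcv1", 10, 1e-5, 32)

class KeywordChild(Parent):
    def __init__(self):
        # Keyword forwarding stays correct when parameters are added
        # or reordered, and fails loudly on a genuinely unknown name.
        super().__init__("bert", "rcv1", epochs=10, lr=1e-5, batch_size=32)

assert PositionalChild().scheduler == 32    # the silent bug
assert KeywordChild().scheduler is None     # the fix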

View File

@@ -26,6 +26,7 @@ class TransformerGen:
         evaluate_step=10,
         verbose=False,
         patience=5,
+        scheduler=None,
     ):
         self.model_name = model_name
         self.dataset_name = dataset_name
self.model_name = model_name
self.dataset_name = dataset_name
@@ -46,6 +47,7 @@ class TransformerGen:
         self.verbose = verbose
         self.patience = patience
         self.datasets = {}
+        self.scheduler = scheduler
         self.feature2posterior_projector = (
             self.make_probabilistic() if probabilistic else None
         )
@@ -101,6 +103,7 @@ class TransformerGen:
             "dataset_name": self.dataset_name,
             "epochs": self.epochs,
             "lr": self.lr,
+            "scheduler": self.scheduler,
             "batch_size": self.batch_size,
             "batch_size_eval": self.batch_size_eval,
             "max_length": self.max_length,
@@ -111,4 +114,4 @@ class TransformerGen:
             "evaluate_step": self.evaluate_step,
             "verbose": self.verbose,
             "patience": self.patience,
-        }
+        }
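TransformerGen now accepts scheduler (defaulting to None), stores it, and records it in the init-params dict alongside the other constructor arguments, so a generator's configuration fully determines the object. A sketch of what that buys you, using a stripped-down stand-in class with only the fields shown in the diff:

class TransformerGenStub:
    def __init__(self, model_name, dataset_name, lr=1e-5, scheduler=None):
        self.model_name = model_name
        self.dataset_name = dataset_name
        self.lr = lr
        self.scheduler = scheduler

    def get_config(self):
        # Mirrors the dict in the diff: every init argument is recorded,
        # so the config can rebuild an identical instance.
        return {
            "model_name": self.model_name,
            "dataset_name": self.dataset_name,
            "lr": self.lr,
            "scheduler": self.scheduler,
        }

gen = TransformerGenStub("vit", "rcv1", scheduler="ReduceLROnPlateau")
clone = TransformerGenStub(**gen.get_config())  # round-trip via **kwargs
assert clone.scheduler == "ReduceLROnPlateau"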

View File

@@ -20,6 +20,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
         model_name,
         dataset_name,
         lr=1e-5,
+        scheduler="ReduceLROnPlateau",
         epochs=10,
         batch_size=32,
         batch_size_eval=128,
@@ -32,8 +33,9 @@ class VisualTransformerGen(ViewGen, TransformerGen):
         super().__init__(
             model_name,
             dataset_name,
-            lr=lr,
             epochs=epochs,
+            lr=lr,
+            scheduler=scheduler,
             batch_size=batch_size,
             batch_size_eval=batch_size_eval,
             device=device,
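Downstream, the stored string reaches the trainer as scheduler_name=self.scheduler (see the TextualTransformerGen hunk above). The trainer itself is not part of this diff; assuming it resolves the name to a torch scheduler, the mapping presumably looks something like this sketch (make_scheduler is a hypothetical helper, not the repo's code):

import torch

def make_scheduler(scheduler_name, optimizer, patience=5):
    if scheduler_name is None:
        return None  # scheduler=None in TransformerGen keeps this optional
    if scheduler_name == "ReduceLROnPlateau":
        # Lowers the lr when the monitored validation metric plateaus.
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.1, patience=patience
        )
    raise ValueError(f"unknown scheduler: {scheduler_name}")

model = torch.nn.Linear(8, 2)
opt = torch.optim.AdamW(model.parameters(), lr=1e-5)
sched = make_scheduler("ReduceLROnPlateau", opt)
sched.step(0.42)  # ReduceLROnPlateau steps on a validation metric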