fixed TransformerGen init

parent b34da419d0
commit ee38bcda10
@@ -59,8 +59,8 @@ class GeneralizedFunnelling:
         # Textual Transformer VGF params ----------
         self.textual_trf_name = textual_transformer_name
         self.epochs = epochs
-        self.txt_trf_lr = textual_lr
-        self.vis_trf_lr = visual_lr
+        self.textual_trf_lr = textual_lr
+        self.textual_scheduler = "ReduceLROnPlateau"
         self.batch_size_trf = batch_size
         self.eval_batch_size_trf = eval_batch_size
         self.max_length = max_length
@@ -70,6 +70,8 @@ class GeneralizedFunnelling:
         self.device = device
         # Visual Transformer VGF params ----------
         self.visual_trf_name = visual_transformer_name
+        self.visual_trf_lr = visual_lr
+        self.visual_scheduler = "ReduceLROnPlateau"
         # Metaclassifier params ------------
         self.optimc = optimc
         # -------------------
@@ -115,7 +117,7 @@ class GeneralizedFunnelling:
             self.attn_aggregator = AttentionAggregator(
                 embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
                 out_dim=self.num_labels,
-                lr=self.txt_trf_lr,
+                lr=self.textual_trf_lr,
                 patience=self.patience,
                 num_heads=1,
                 device=self.device,
@@ -149,7 +151,8 @@ class GeneralizedFunnelling:
             transformer_vgf = TextualTransformerGen(
                 dataset_name=self.dataset_name,
                 model_name=self.textual_trf_name,
-                lr=self.txt_trf_lr,
+                lr=self.textual_trf_lr,
+                scheduler=self.textual_scheduler,
                 epochs=self.epochs,
                 batch_size=self.batch_size_trf,
                 batch_size_eval=self.eval_batch_size_trf,
@@ -168,7 +171,8 @@ class GeneralizedFunnelling:
             visual_trasformer_vgf = VisualTransformerGen(
                 dataset_name=self.dataset_name,
                 model_name="vit",
-                lr=self.vis_trf_lr,
+                lr=self.visual_trf_lr,
+                scheduler=self.visual_scheduler,
                 epochs=self.epochs,
                 batch_size=self.batch_size_trf,
                 batch_size_eval=self.eval_batch_size_trf,
@@ -185,7 +189,7 @@ class GeneralizedFunnelling:
             self.attn_aggregator = AttentionAggregator(
                 embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
                 out_dim=self.num_labels,
-                lr=self.txt_trf_lr,
+                lr=self.textual_trf_lr,
                 patience=self.patience,
                 num_heads=1,
                 device=self.device,
@@ -346,7 +350,7 @@ class GeneralizedFunnelling:
             pickle.dump(self.metaclassifier, f)
         return

-    def save_first_tier_learners(self):
+    def save_first_tier_learners(self, model_id):
         for vgf in self.first_tier_learners:
             vgf.save_vgf(model_id=self._model_id)
         return self
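The first two hunks split the old shared learning rates into per-modality attributes (textual_trf_lr, visual_trf_lr) and pin both VGFs to a "ReduceLROnPlateau" scheduler, stored as a plain string. A minimal sketch, assuming the string is later resolved into a torch scheduler by the downstream trainer (the resolution code is not part of this diff):

import torch

def build_scheduler(name, optimizer, patience=5):
    # Resolve a scheduler name string (e.g. self.textual_scheduler) into a
    # concrete torch.optim.lr_scheduler object; None disables scheduling.
    if name is None:
        return None
    if name == "ReduceLROnPlateau":
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", patience=patience
        )
    raise ValueError(f"unsupported scheduler: {name}")

model = torch.nn.Linear(16, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = build_scheduler("ReduceLROnPlateau", optimizer)
scheduler.step(0.42)  # ReduceLROnPlateau steps on a monitored metric, e.g. val loss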
@@ -70,22 +70,24 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         verbose=False,
         patience=5,
         classification_type="multilabel",
+        scheduler="ReduceLROnPlateau",
     ):
         super().__init__(
             self._validate_model_name(model_name),
             dataset_name,
-            epochs,
-            lr,
-            batch_size,
-            batch_size_eval,
-            max_length,
-            print_steps,
-            device,
-            probabilistic,
-            n_jobs,
-            evaluate_step,
-            verbose,
-            patience,
+            epochs=epochs,
+            lr=lr,
+            scheduler=scheduler,
+            batch_size=batch_size,
+            batch_size_eval=batch_size_eval,
+            device=device,
+            evaluate_step=evaluate_step,
+            patience=patience,
+            probabilistic=probabilistic,
+            max_length=max_length,
+            print_steps=print_steps,
+            n_jobs=n_jobs,
+            verbose=verbose,
         )
         self.clf_type = classification_type
         self.fitted = False
@@ -187,8 +189,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             vgf_name="textual_trf",
             classification_type=self.clf_type,
             n_jobs=self.n_jobs,
-            scheduler_name="ReduceLROnPlateau",
-            # scheduler_name=None,
+            scheduler_name=self.scheduler,
         )
         trainer.train(
             train_dataloader=tra_dataloader,
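The super().__init__() rewrite above is the fix the commit title refers to: the old call forwarded its arguments positionally, so inserting a scheduler parameter into the parent's signature would silently misalign every argument after it. A standalone toy reproduction of the failure mode (illustrative classes, not the real ones):

class Parent:
    def __init__(self, model_name, dataset_name, epochs=10, lr=1e-5,
                 scheduler=None, batch_size=32):
        self.scheduler = scheduler
        self.batch_size = batch_size

class PositionalChild(Parent):
    def __init__(self):
        # 32 was meant to be batch_size, but once `scheduler` exists in the
        # parent signature it lands there instead.
        super().__init__("bert", "rcv1", 10, 1e-5, 32)

class KeywordChild(Parent):
    def __init__(self):
        # Keyword forwarding stays correct no matter where new parameters
        # are inserted in the parent signature.
        super().__init__("bert", "rcv1", epochs=10, lr=1e-5, batch_size=32)

print(PositionalChild().scheduler)  # 32  -- the bug
print(KeywordChild().scheduler)     # None
print(KeywordChild().batch_size)    # 32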
@@ -26,6 +26,7 @@ class TransformerGen:
         evaluate_step=10,
         verbose=False,
         patience=5,
+        scheduler=None,
     ):
         self.model_name = model_name
         self.dataset_name = dataset_name
@@ -46,6 +47,7 @@ class TransformerGen:
         self.verbose = verbose
         self.patience = patience
         self.datasets = {}
+        self.scheduler = scheduler
         self.feature2posterior_projector = (
             self.make_probabilistic() if probabilistic else None
         )
@@ -101,6 +103,7 @@ class TransformerGen:
             "dataset_name": self.dataset_name,
             "epochs": self.epochs,
             "lr": self.lr,
+            "scheduler": self.scheduler,
             "batch_size": self.batch_size,
             "batch_size_eval": self.batch_size_eval,
             "max_length": self.max_length,
@@ -111,4 +114,4 @@ class TransformerGen:
             "evaluate_step": self.evaluate_step,
             "verbose": self.verbose,
             "patience": self.patience,
-        }
+        }
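TransformerGen defaults the new parameter to None, so call sites that never pass a scheduler keep working, and the last two hunks surface the value in what appears to be the class's hyperparameter dict. A standalone stand-in (hypothetical names, not the repo's class) illustrating both points:

class TransformerGenSketch:
    def __init__(self, model_name, dataset_name, epochs=10, lr=1e-5,
                 scheduler=None):
        self.model_name = model_name
        self.dataset_name = dataset_name
        self.epochs = epochs
        self.lr = lr
        self.scheduler = scheduler  # new attribute, mirrored into the dict

    def hyperparams(self):
        return {
            "dataset_name": self.dataset_name,
            "epochs": self.epochs,
            "lr": self.lr,
            "scheduler": self.scheduler,  # newly serialized entry
        }

old_style = TransformerGenSketch("bert", "rcv1")  # no scheduler arg: still valid
new_style = TransformerGenSketch("vit", "rcv1", scheduler="ReduceLROnPlateau")
print(old_style.hyperparams()["scheduler"])  # None
print(new_style.hyperparams()["scheduler"])  # ReduceLROnPlateau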
@@ -20,6 +20,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
         model_name,
         dataset_name,
         lr=1e-5,
+        scheduler="ReduceLROnPlateau",
         epochs=10,
         batch_size=32,
         batch_size_eval=128,
@@ -32,8 +33,9 @@ class VisualTransformerGen(ViewGen, TransformerGen):
         super().__init__(
             model_name,
             dataset_name,
-            lr=lr,
             epochs=epochs,
+            lr=lr,
+            scheduler=scheduler,
             batch_size=batch_size,
             batch_size_eval=batch_size_eval,
             device=device,