Fixed TransformerGen init: pass arguments to the base constructor by keyword and add a configurable `scheduler` parameter
This commit is contained in:
parent
b34da419d0
commit
ee38bcda10
|
@ -59,8 +59,8 @@ class GeneralizedFunnelling:
|
||||||
# Textual Transformer VGF params ----------
|
# Textual Transformer VGF params ----------
|
||||||
self.textual_trf_name = textual_transformer_name
|
self.textual_trf_name = textual_transformer_name
|
||||||
self.epochs = epochs
|
self.epochs = epochs
|
||||||
self.txt_trf_lr = textual_lr
|
self.textual_trf_lr = textual_lr
|
||||||
self.vis_trf_lr = visual_lr
|
self.textual_scheduler = "ReduceLROnPlateau"
|
||||||
self.batch_size_trf = batch_size
|
self.batch_size_trf = batch_size
|
||||||
self.eval_batch_size_trf = eval_batch_size
|
self.eval_batch_size_trf = eval_batch_size
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
|
@ -70,6 +70,8 @@ class GeneralizedFunnelling:
|
||||||
self.device = device
|
self.device = device
|
||||||
# Visual Transformer VGF params ----------
|
# Visual Transformer VGF params ----------
|
||||||
self.visual_trf_name = visual_transformer_name
|
self.visual_trf_name = visual_transformer_name
|
||||||
|
self.visual_trf_lr = visual_lr
|
||||||
|
self.visual_scheduler = "ReduceLROnPlateau"
|
||||||
# Metaclassifier params ------------
|
# Metaclassifier params ------------
|
||||||
self.optimc = optimc
|
self.optimc = optimc
|
||||||
# -------------------
|
# -------------------
|
||||||
|
@ -115,7 +117,7 @@ class GeneralizedFunnelling:
|
||||||
self.attn_aggregator = AttentionAggregator(
|
self.attn_aggregator = AttentionAggregator(
|
||||||
embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
|
embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
|
||||||
out_dim=self.num_labels,
|
out_dim=self.num_labels,
|
||||||
lr=self.txt_trf_lr,
|
lr=self.textual_trf_lr,
|
||||||
patience=self.patience,
|
patience=self.patience,
|
||||||
num_heads=1,
|
num_heads=1,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
|
@ -149,7 +151,8 @@ class GeneralizedFunnelling:
|
||||||
transformer_vgf = TextualTransformerGen(
|
transformer_vgf = TextualTransformerGen(
|
||||||
dataset_name=self.dataset_name,
|
dataset_name=self.dataset_name,
|
||||||
model_name=self.textual_trf_name,
|
model_name=self.textual_trf_name,
|
||||||
lr=self.txt_trf_lr,
|
lr=self.textual_trf_lr,
|
||||||
|
scheduler=self.textual_scheduler,
|
||||||
epochs=self.epochs,
|
epochs=self.epochs,
|
||||||
batch_size=self.batch_size_trf,
|
batch_size=self.batch_size_trf,
|
||||||
batch_size_eval=self.eval_batch_size_trf,
|
batch_size_eval=self.eval_batch_size_trf,
|
||||||
|
@ -168,7 +171,8 @@ class GeneralizedFunnelling:
|
||||||
visual_trasformer_vgf = VisualTransformerGen(
|
visual_trasformer_vgf = VisualTransformerGen(
|
||||||
dataset_name=self.dataset_name,
|
dataset_name=self.dataset_name,
|
||||||
model_name="vit",
|
model_name="vit",
|
||||||
lr=self.vis_trf_lr,
|
lr=self.visual_trf_lr,
|
||||||
|
scheduler=self.visual_scheduler,
|
||||||
epochs=self.epochs,
|
epochs=self.epochs,
|
||||||
batch_size=self.batch_size_trf,
|
batch_size=self.batch_size_trf,
|
||||||
batch_size_eval=self.eval_batch_size_trf,
|
batch_size_eval=self.eval_batch_size_trf,
|
||||||
|
@ -185,7 +189,7 @@ class GeneralizedFunnelling:
|
||||||
self.attn_aggregator = AttentionAggregator(
|
self.attn_aggregator = AttentionAggregator(
|
||||||
embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
|
embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
|
||||||
out_dim=self.num_labels,
|
out_dim=self.num_labels,
|
||||||
lr=self.txt_trf_lr,
|
lr=self.textual_trf_lr,
|
||||||
patience=self.patience,
|
patience=self.patience,
|
||||||
num_heads=1,
|
num_heads=1,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
|
@ -346,7 +350,7 @@ class GeneralizedFunnelling:
|
||||||
pickle.dump(self.metaclassifier, f)
|
pickle.dump(self.metaclassifier, f)
|
||||||
return
|
return
|
||||||
|
|
||||||
def save_first_tier_learners(self):
|
def save_first_tier_learners(self, model_id):
|
||||||
for vgf in self.first_tier_learners:
|
for vgf in self.first_tier_learners:
|
||||||
vgf.save_vgf(model_id=self._model_id)
|
vgf.save_vgf(model_id=self._model_id)
|
||||||
return self
|
return self
|
||||||
|
|
|
@ -70,22 +70,24 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
||||||
verbose=False,
|
verbose=False,
|
||||||
patience=5,
|
patience=5,
|
||||||
classification_type="multilabel",
|
classification_type="multilabel",
|
||||||
|
scheduler="ReduceLROnPlateau",
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
self._validate_model_name(model_name),
|
self._validate_model_name(model_name),
|
||||||
dataset_name,
|
dataset_name,
|
||||||
epochs,
|
epochs=epochs,
|
||||||
lr,
|
lr=lr,
|
||||||
batch_size,
|
scheduler=scheduler,
|
||||||
batch_size_eval,
|
batch_size=batch_size,
|
||||||
max_length,
|
batch_size_eval=batch_size_eval,
|
||||||
print_steps,
|
device=device,
|
||||||
device,
|
evaluate_step=evaluate_step,
|
||||||
probabilistic,
|
patience=patience,
|
||||||
n_jobs,
|
probabilistic=probabilistic,
|
||||||
evaluate_step,
|
max_length=max_length,
|
||||||
verbose,
|
print_steps=print_steps,
|
||||||
patience,
|
n_jobs=n_jobs,
|
||||||
|
verbose=verbose,
|
||||||
)
|
)
|
||||||
self.clf_type = classification_type
|
self.clf_type = classification_type
|
||||||
self.fitted = False
|
self.fitted = False
|
||||||
|
@ -187,8 +189,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
||||||
vgf_name="textual_trf",
|
vgf_name="textual_trf",
|
||||||
classification_type=self.clf_type,
|
classification_type=self.clf_type,
|
||||||
n_jobs=self.n_jobs,
|
n_jobs=self.n_jobs,
|
||||||
scheduler_name="ReduceLROnPlateau",
|
scheduler_name=self.scheduler,
|
||||||
# scheduler_name=None,
|
|
||||||
)
|
)
|
||||||
trainer.train(
|
trainer.train(
|
||||||
train_dataloader=tra_dataloader,
|
train_dataloader=tra_dataloader,
|
||||||
|
|
|
@ -26,6 +26,7 @@ class TransformerGen:
|
||||||
evaluate_step=10,
|
evaluate_step=10,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
patience=5,
|
patience=5,
|
||||||
|
scheduler=None,
|
||||||
):
|
):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.dataset_name = dataset_name
|
self.dataset_name = dataset_name
|
||||||
|
@ -46,6 +47,7 @@ class TransformerGen:
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.patience = patience
|
self.patience = patience
|
||||||
self.datasets = {}
|
self.datasets = {}
|
||||||
|
self.scheduler = scheduler
|
||||||
self.feature2posterior_projector = (
|
self.feature2posterior_projector = (
|
||||||
self.make_probabilistic() if probabilistic else None
|
self.make_probabilistic() if probabilistic else None
|
||||||
)
|
)
|
||||||
|
@ -101,6 +103,7 @@ class TransformerGen:
|
||||||
"dataset_name": self.dataset_name,
|
"dataset_name": self.dataset_name,
|
||||||
"epochs": self.epochs,
|
"epochs": self.epochs,
|
||||||
"lr": self.lr,
|
"lr": self.lr,
|
||||||
|
"scheduler": self.scheduler,
|
||||||
"batch_size": self.batch_size,
|
"batch_size": self.batch_size,
|
||||||
"batch_size_eval": self.batch_size_eval,
|
"batch_size_eval": self.batch_size_eval,
|
||||||
"max_length": self.max_length,
|
"max_length": self.max_length,
|
||||||
|
@ -111,4 +114,4 @@ class TransformerGen:
|
||||||
"evaluate_step": self.evaluate_step,
|
"evaluate_step": self.evaluate_step,
|
||||||
"verbose": self.verbose,
|
"verbose": self.verbose,
|
||||||
"patience": self.patience,
|
"patience": self.patience,
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
||||||
model_name,
|
model_name,
|
||||||
dataset_name,
|
dataset_name,
|
||||||
lr=1e-5,
|
lr=1e-5,
|
||||||
|
scheduler="ReduceLROnPlateau",
|
||||||
epochs=10,
|
epochs=10,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
batch_size_eval=128,
|
batch_size_eval=128,
|
||||||
|
@ -32,8 +33,9 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model_name,
|
model_name,
|
||||||
dataset_name,
|
dataset_name,
|
||||||
lr=lr,
|
|
||||||
epochs=epochs,
|
epochs=epochs,
|
||||||
|
lr=lr,
|
||||||
|
scheduler=scheduler,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
batch_size_eval=batch_size_eval,
|
batch_size_eval=batch_size_eval,
|
||||||
device=device,
|
device=device,
|
||||||
|
|
Loading…
Reference in New Issue