This commit is contained in:
Andrea Pedrotti 2023-03-07 14:33:30 +01:00
parent 7dead90271
commit 6b7917ca47
1 changed files with 10 additions and 10 deletions

View File

@ -47,8 +47,8 @@ class GeneralizedFunnelling:
self.posteriors_vgf = posterior
self.wce_vgf = wce
self.multilingual_vgf = multilingual
self.textual_trasformer_vgf = textual_transformer
self.visual_transformer_vgf = visual_transformer
self.textual_trf_vgf = textual_transformer
self.visual_trf_vgf = visual_transformer
self.probabilistic = probabilistic
self.num_labels = num_labels
# ------------------------
@ -56,7 +56,7 @@ class GeneralizedFunnelling:
self.embed_dir = embed_dir
self.cached = True
# Textual Transformer VGF params ----------
self.textaul_transformer_name = textual_transformer_name
self.textual_trf_name = textual_transformer_name
self.epochs = epochs
self.lr_transformer = lr
self.batch_size_transformer = batch_size
@ -66,7 +66,7 @@ class GeneralizedFunnelling:
self.evaluate_step = evaluate_step
self.device = device
# Visual Transformer VGF params ----------
self.visual_transformer_name = visual_transformer_name
self.visual_trf_name = visual_transformer_name
# Metaclassifier params ------------
self.optimc = optimc
# -------------------
@ -142,10 +142,10 @@ class GeneralizedFunnelling:
wce_vgf = WceGen(n_jobs=self.n_jobs)
self.first_tier_learners.append(wce_vgf)
if self.textual_trasformer_vgf:
if self.textual_trf_vgf:
transformer_vgf = TextualTransformerGen(
dataset_name=self.dataset_name,
model_name=self.textaul_transformer_name,
model_name=self.textual_trf_name,
lr=self.lr_transformer,
epochs=self.epochs,
batch_size=self.batch_size_transformer,
@ -159,7 +159,7 @@ class GeneralizedFunnelling:
)
self.first_tier_learners.append(transformer_vgf)
if self.visual_transformer_vgf:
if self.visual_trf_vgf:
visual_trasformer_vgf = VisualTransformerGen(
dataset_name=self.dataset_name,
model_name="vit",
@ -198,8 +198,8 @@ class GeneralizedFunnelling:
self.posteriors_vgf,
self.multilingual_vgf,
self.wce_vgf,
self.textual_trasformer_vgf,
self.visual_transformer_vgf,
self.textual_trf_vgf,
self.visual_trf_vgf,
self.aggfunc,
)
print(f"- model id: {self._model_id}")
@ -373,7 +373,7 @@ class GeneralizedFunnelling:
"rb",
) as vgf:
first_tier_learners.append(pickle.load(vgf))
if self.textual_trasformer_vgf:
if self.textual_trf_vgf:
with open(
os.path.join(
"models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl"