todo updates
parent 41647f974a
commit ab7a310b34
@@ -118,7 +118,6 @@ class gFunDataset:
         if self.data_langs is None:
             data_langs = sorted(train_split.geo.unique().tolist())
-        # TODO: if data langs is NOT none then we have a problem where we filter df by langs
         if self.labels is None:
             labels = train_split.category_name.unique().tolist()
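The TODO removed above flags a real gap: when data_langs is passed explicitly (not None), the splits are never actually restricted to those languages. Below is a minimal sketch of the missing filtering, assuming train_split is a pandas DataFrame whose geo column holds the language code, as in the hunk above; the helper name filter_split_by_langs is hypothetical and not part of the repository.

import pandas as pd

def filter_split_by_langs(split: pd.DataFrame, data_langs):
    # Hypothetical helper: restrict a split to the requested languages.
    # When data_langs is None the split is returned untouched, matching the
    # existing "if self.data_langs is None" branch.
    if data_langs is None:
        return split
    return split[split.geo.isin(data_langs)]

# Usage sketch: train_split = filter_split_by_langs(train_split, self.data_langs)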
@@ -82,7 +82,7 @@ class GeneralizedFunnelling:
         self.aggfunc = aggfunc
         self.load_trained = load_trained
         self.load_first_tier = (
-            True # TODO: i guess we're always going to load at least the fitst tier
+            True # TODO: i guess we're always going to load at least the first tier
         )
         self.load_meta = load_meta
         self.dataset_name = dataset_name
@@ -104,10 +104,6 @@ class MultilingualGen(ViewGen):
             pickle.dump(self, f)
         return self
 
-    def __str__(self):
-        _str = f"[Multilingual VGF (m)]\n- embed_dir: {self.embed_dir}\n- langs: {self.langs}\n- n_jobs: {self.n_jobs}\n- cached: {self.cached}\n- sif: {self.sif}\n- probabilistic: {self.probabilistic}\n"
-        return _str
-
 
 def load_MUSEs(langs, l_vocab, dir_path, cached=False):
     dir_path = expanduser(dir_path)
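load_MUSEs takes per-language vocabularies and a directory of aligned MUSE embeddings. For orientation, here is a hedged sketch of what such a loader typically does for one language, assuming the standard wiki.multi.<lang>.vec text format (a header line followed by one word plus its vector per line) and that l_vocab maps each language to its vocabulary; this is an illustration, not the repository's implementation.

import numpy as np
from os.path import expanduser, join

def load_muse_vectors(lang, l_vocab, dir_path, dim=300):
    # Hypothetical loader: keep only vectors for words in the language vocabulary.
    vectors = {}
    path = join(expanduser(dir_path), f"wiki.multi.{lang}.vec")  # assumed file name
    with open(path, encoding="utf-8") as f:
        next(f)  # skip the "<n_words> <dim>" header line
        for line in f:
            word, *values = line.rstrip().split(" ")
            if word in l_vocab.get(lang, {}):
                vectors[word] = np.asarray(values, dtype=float)[:dim]
    return vectors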
@@ -193,7 +193,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         )
         trainer.train(
             train_dataloader=tra_dataloader,
-            eval_dataloader=val_dataloader, # TODO: debug setting
+            eval_dataloader=val_dataloader,
             epochs=self.epochs,
         )
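With the "debug setting" note gone, eval_dataloader is simply the validation loader. The VGF parameters printed elsewhere in this diff (evaluate_step, patience) suggest periodic validation with early stopping; the following is a hedged, self-contained sketch of that pattern in plain PyTorch, not the repository's Trainer class. The loss, optimizer, and (x, y) batch format are assumptions.

import torch

def train(model, train_dataloader, eval_dataloader, epochs,
          evaluate_step=1, patience=3, lr=1e-5, device="cpu"):
    # Sketch only: train, evaluate every `evaluate_step` epochs, and stop early
    # once validation loss has not improved for `patience` evaluations.
    model.to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = torch.nn.BCEWithLogitsLoss()  # multi-label setting assumed
    best_val, bad_evals = float("inf"), 0
    for epoch in range(epochs):
        model.train()
        for x, y in train_dataloader:
            optim.zero_grad()
            loss = loss_fn(model(x.to(device)), y.to(device))
            loss.backward()
            optim.step()
        if (epoch + 1) % evaluate_step == 0:
            model.eval()
            with torch.no_grad():
                val = sum(loss_fn(model(x.to(device)), y.to(device)).item()
                          for x, y in eval_dataloader)
            if val < best_val:
                best_val, bad_evals = val, 0
            else:
                bad_evals += 1
                if bad_evals >= patience:
                    break
    return model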
@@ -275,10 +275,6 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         else:
             return model_name
 
-    def __str__(self):
-        str = f"[Transformer VGF (t)]\n- model_name: {self.model_name}\n- max_length: {self.max_length}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
-        return str
-
     def get_config(self):
         c = super().get_config()
         return {"textual_trf": c}
@@ -65,8 +65,3 @@ class VanillaFunGen(ViewGen):
         with open(_path, "wb") as f:
             pickle.dump(self, f)
         return self
-
-    def __str__(self):
-        _str = f"[VanillaFunGen (-p)]\n- base learner: {self.learners}\n- n_jobs: {self.n_jobs}\n"
-        # - parameters: {self.first_tier_parameters}
-        return _str
@@ -185,9 +185,5 @@ class VisualTransformerGen(ViewGen, TransformerGen):
             pickle.dump(self, f)
         return self
 
-    def __str__(self):
-        str = f"[Visual Transformer VGF (v)]\n- model_name: {self.model_name}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
-        return str
-
     def get_config(self):
         return {"visual_trf": super().get_config()}
@@ -40,10 +40,6 @@ class WceGen(ViewGen):
             "sif": self.sif,
         }
 
-    def __str__(self):
-        _str = f"[WordClass VGF (w)]\n- sif: {self.sif}\n- n_jobs: {self.n_jobs}\n"
-        return _str
-
     def save_vgf(self, model_id):
         import pickle
         from os.path import join
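The WceGen hunk stops right after save_vgf's local imports, but the other VGFs in this diff all persist themselves the same way: pickle.dump(self, f) followed by return self. A sketch of how the rest of the method plausibly continues; the output directory and filename scheme are invented for illustration and are not the repository's actual values.

def save_vgf(self, model_id):
    import pickle
    from os.path import join

    # Hypothetical continuation mirroring the pickle.dump(self, f) / return self
    # pattern shown in the other hunks; path and filename are assumptions.
    _path = join("models", "vgfs", f"wceGen_{model_id}.pkl")
    with open(_path, "wb") as f:
        pickle.dump(self, f)
    return self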