Better way to save/load model via id ({config}_{date}); implemented __str__ for each VGF and get_config in GeneralizedFunnelling
commit 19e4f294db · parent 31fb436cf0
@@ -13,8 +13,6 @@ from vgfs.transformerGen import TransformerGen
 from vgfs.vanillaFun import VanillaFunGen
 from vgfs.wceGen import WceGen
 
-# TODO: save and load gfun model
-
 
 class GeneralizedFunnelling:
     def __init__(
@@ -37,7 +35,7 @@ class GeneralizedFunnelling:
         device,
         load_trained,
     ):
-        # Forcing VFGs -----------
+        # Setting VFGs -----------
         self.posteriors_vgf = posterior
         self.wce_vgf = wce
         self.multilingual_vgf = multilingual
@@ -69,9 +67,11 @@ class GeneralizedFunnelling:
 
     def _init(self):
         print("[Init GeneralizedFunnelling]")
-        if self.load_trained:
+        if self.load_trained is not None:
             print("- loading trained VGFs, metaclassifer and vectorizer")
-            self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load()
+            self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load(
+                self.load_trained
+            )
             # TODO: config like aggfunc, device, n_jobs, etc
             return self
 
@@ -118,6 +118,14 @@ class GeneralizedFunnelling:
             meta_parameters=get_params(self.optimc),
             n_jobs=self.n_jobs,
         )
+
+        self._model_id = get_unique_id(
+            self.posteriors_vgf,
+            self.multilingual_vgf,
+            self.wce_vgf,
+            self.trasformer_vgf,
+        )
+        print(f"- model id: {self._model_id}")
         return self
 
     def init_vgfs_vectorizers(self):
@@ -127,12 +135,12 @@ class GeneralizedFunnelling:
 
     def fit(self, lX, lY):
         print("[Fitting GeneralizedFunnelling]")
-        if self.load_trained:
+        if self.load_trained is not None:
             print(f"- loaded trained model! Skipping training...")
-            load_only_first_tier = False  # TODO
+            # TODO: add support to load only the first tier learners while re-training the metaclassifier
+            load_only_first_tier = False
             if load_only_first_tier:
-                projections = []
-                # TODO project, aggregate and fit the metaclassifier
+                raise NotImplementedError
             return self
 
         self.vectorizer.fit(lX)
@@ -169,7 +177,6 @@ class GeneralizedFunnelling:
         return aggregated
 
     def _aggregate_mean(self, first_tier_projections):
-        # TODO: deafult dict for one-liner?
         aggregated = {
             lang: np.zeros(data.shape)
             for lang, data in first_tier_projections[0].items()
|
@ -185,55 +192,75 @@ class GeneralizedFunnelling:
|
||||||
return aggregated
|
return aggregated
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
from pprint import pprint
|
print("\n")
|
||||||
|
print("-" * 50)
|
||||||
# TODO
|
|
||||||
print("[GeneralizedFunnelling config]")
|
print("[GeneralizedFunnelling config]")
|
||||||
print(f"- langs: {self.langs}")
|
print(f"- model trained on langs: {self.langs}")
|
||||||
print("-- vgfs:")
|
print("-- View Generating Functions configurations:\n")
|
||||||
|
|
||||||
for vgf in self.first_tier_learners:
|
for vgf in self.first_tier_learners:
|
||||||
pprint(vgf.get_config())
|
print(vgf)
|
||||||
|
print("-" * 50)
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
|
print(f"- Saving GeneralizedFunnelling model with id: {self._model_id}")
|
||||||
|
# TODO: save only the first tier learners? what about some model config + sanity checks before loading?
|
||||||
for vgf in self.first_tier_learners:
|
for vgf in self.first_tier_learners:
|
||||||
vgf.save_vgf()
|
vgf.save_vgf(model_id=self._model_id)
|
||||||
# Saving metaclassifier
|
os.makedirs(os.path.join("models", "metaclassifier"), exist_ok=True)
|
||||||
with open(os.path.join("models", "metaclassifier", "meta_todo.pkl"), "wb") as f:
|
|
||||||
pickle.dump(self.metaclassifier, f)
|
|
||||||
# Saving vectorizer
|
|
||||||
with open(
|
with open(
|
||||||
os.path.join("models", "vectorizer", "vectorizer_todo.pkl"), "wb"
|
os.path.join("models", "metaclassifier", f"meta_{self._model_id}.pkl"), "wb"
|
||||||
|
) as f:
|
||||||
|
pickle.dump(self.metaclassifier, f)
|
||||||
|
os.makedirs(os.path.join("models", "vectorizer"), exist_ok=True)
|
||||||
|
with open(
|
||||||
|
os.path.join("models", "vectorizer", f"vectorizer_{self._model_id}.pkl"),
|
||||||
|
"wb",
|
||||||
) as f:
|
) as f:
|
||||||
pickle.dump(self.vectorizer, f)
|
pickle.dump(self.vectorizer, f)
|
||||||
# TODO: save some config and perform sanity checks?
|
|
||||||
return
|
return
|
||||||
|
|
||||||
def load(self):
|
def load(self, model_id):
|
||||||
|
print(f"- loading model id: {model_id}")
|
||||||
first_tier_learners = []
|
first_tier_learners = []
|
||||||
if self.posteriors_vgf:
|
if self.posteriors_vgf:
|
||||||
# FIXME: hardcoded
|
|
||||||
with open(
|
with open(
|
||||||
os.path.join("models", "vgfs", "posteriors", "vanillaFunGen_todo.pkl"),
|
os.path.join(
|
||||||
|
"models", "vgfs", "posterior", f"vanillaFunGen_{model_id}.pkl"
|
||||||
|
),
|
||||||
"rb",
|
"rb",
|
||||||
) as vgf:
|
) as vgf:
|
||||||
first_tier_learners.append(pickle.load(vgf))
|
first_tier_learners.append(pickle.load(vgf))
|
||||||
if self.multilingual_vgf:
|
if self.multilingual_vgf:
|
||||||
# FIXME: hardcoded
|
with open(
|
||||||
with open("models/vgfs/multilingual/vanillaFunGen_todo.pkl") as vgf:
|
os.path.join(
|
||||||
|
"models", "vgfs", "multilingual", f"multilingualGen_{model_id}.pkl"
|
||||||
|
),
|
||||||
|
"rb",
|
||||||
|
) as vgf:
|
||||||
first_tier_learners.append(pickle.load(vgf))
|
first_tier_learners.append(pickle.load(vgf))
|
||||||
if self.wce_vgf:
|
if self.wce_vgf:
|
||||||
# FIXME: hardcoded
|
with open(
|
||||||
with open("models/vgfs/wordclass/vanillaFunGen_todo.pkl") as vgf:
|
os.path.join(
|
||||||
|
"models", "vgfs", "wordclass", f"wordClassGen_{model_id}.pkl"
|
||||||
|
),
|
||||||
|
"rb",
|
||||||
|
) as vgf:
|
||||||
first_tier_learners.append(pickle.load(vgf))
|
first_tier_learners.append(pickle.load(vgf))
|
||||||
if self.trasformer_vgf:
|
if self.trasformer_vgf:
|
||||||
# FIXME: hardcoded
|
with open(
|
||||||
with open("models/vgfs/transformers/vanillaFunGen_todo.pkl") as vgf:
|
os.path.join(
|
||||||
|
"models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl"
|
||||||
|
),
|
||||||
|
"rb",
|
||||||
|
) as vgf:
|
||||||
first_tier_learners.append(pickle.load(vgf))
|
first_tier_learners.append(pickle.load(vgf))
|
||||||
with open(os.path.join("models", "metaclassifier", "meta_todo.pkl"), "rb") as f:
|
with open(
|
||||||
|
os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb"
|
||||||
|
) as f:
|
||||||
metaclassifier = pickle.load(f)
|
metaclassifier = pickle.load(f)
|
||||||
with open(
|
with open(
|
||||||
os.path.join("models", "vectorizer", "vectorizer_todo.pkl"), "rb"
|
os.path.join("models", "vectorizer", f"vectorizer_{model_id}.pkl"), "rb"
|
||||||
) as f:
|
) as f:
|
||||||
vectorizer = pickle.load(f)
|
vectorizer = pickle.load(f)
|
||||||
return first_tier_learners, metaclassifier, vectorizer
|
return first_tier_learners, metaclassifier, vectorizer
|
||||||
|
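With these changes, save() lays everything out on disk keyed by the model id. Assuming all four VGFs are enabled, the resulting layout sketched from the paths in the diff above would be:

    models/
        vgfs/
            posterior/vanillaFunGen_<model_id>.pkl
            multilingual/multilingualGen_<model_id>.pkl
            wordclass/wordClassGen_<model_id>.pkl
            transformer/transformerGen_<model_id>.pkl
        metaclassifier/meta_<model_id>.pkl
        vectorizer/vectorizer_<model_id>.pkl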
@@ -245,3 +272,15 @@ def get_params(optimc=False):
     c_range = [1e4, 1e3, 1e2, 1e1, 1, 1e-1]
     kernel = "rbf"
     return [{"kernel": [kernel], "C": c_range, "gamma": ["auto"]}]
+
+
+def get_unique_id(posterior, multilingual, wce, transformer):
+    from datetime import datetime
+
+    now = datetime.now().strftime("%y%m%d")
+    model_id = ""
+    model_id += "p" if posterior else ""
+    model_id += "m" if multilingual else ""
+    model_id += "w" if wce else ""
+    model_id += "t" if transformer else ""
+    return f"{model_id}_{now}"
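The id encodes the active VGFs as one letter each plus the date, i.e. the {config}_{date} scheme from the commit message. A minimal usage sketch (dates shown are hypothetical):

    # all four VGFs enabled, run on 2023-03-01 -> "pmwt_230301"
    get_unique_id(posterior=True, multilingual=True, wce=True, transformer=True)

    # only posteriors and WCE -> "pw_230301"
    get_unique_id(posterior=True, multilingual=False, wce=True, transformer=False)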
@@ -73,7 +73,6 @@ class MonolingualClassifier:
 
         # parameter optimization?
         if self.parameters:
-            print("debug: optimizing parameters:", self.parameters)
             self.model = GridSearchCV(
                 self.model,
                 param_grid=self.parameters,
@@ -81,7 +80,7 @@ class MonolingualClassifier:
                 cv=5,
                 n_jobs=self.n_jobs,
                 error_score=0,
-                verbose=10,
+                verbose=1,
             )
 
         # print(f"-- Fitting learner on matrices X={X.shape} Y={y.shape}")
@@ -91,6 +91,23 @@ class MultilingualGen(ViewGen):
             "probabilistic": self.probabilistic,
         }
 
+    def save_vgf(self, model_id):
+        import pickle
+        from os.path import join
+        from os import makedirs
+
+        vgf_name = "multilingualGen"
+        _basedir = join("models", "vgfs", "multilingual")
+        makedirs(_basedir, exist_ok=True)
+        _path = join(_basedir, f"{vgf_name}_{model_id}.pkl")
+        with open(_path, "wb") as f:
+            pickle.dump(self, f)
+        return self
+
+    def __str__(self):
+        _str = f"[Multilingual VGF (m)]\n- embed_dir: {self.embed_dir}\n- langs: {self.langs}\n- n_jobs: {self.n_jobs}\n- cached: {self.cached}\n- sif: {self.sif}\n- probabilistic: {self.probabilistic}\n"
+        return _str
+
 
 def load_MUSEs(langs, l_vocab, dir_path, cached=False):
     dir_path = expanduser(dir_path)
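With __str__ implemented on each VGF, gfun.get_config() now prints a block like the following (all values here are hypothetical, for illustration only):

    --------------------------------------------------
    [GeneralizedFunnelling config]
    - model trained on langs: ['en', 'it']
    -- View Generating Functions configurations:

    [Multilingual VGF (m)]
    - embed_dir: ~/embeddings/muse
    - langs: ['en', 'it']
    - n_jobs: -1
    - cached: True
    - sif: True
    - probabilistic: True
    --------------------------------------------------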
@@ -18,8 +18,7 @@ from evaluation.evaluate import evaluate, log_eval
 transformers.logging.set_verbosity_error()
 
 
-# TODO: early stopping, checkpointing, logging, model loading
-# TODO: experiment name
+# TODO: add support to loggers
 
 
 class TransformerGen:
@@ -29,6 +28,7 @@ class TransformerGen:
         epochs=10,
         lr=1e-5,
         batch_size=4,
+        batch_size_eval=32,
         max_length=512,
         print_steps=50,
         device="cpu",
@@ -46,6 +46,7 @@ class TransformerGen:
         self.tokenizer = None
         self.max_length = max_length
         self.batch_size = batch_size
+        self.batch_size_eval = batch_size_eval
         self.print_steps = print_steps
         self.probabilistic = probabilistic
         self.n_jobs = n_jobs
@@ -137,10 +138,10 @@ class TransformerGen:
         )
 
         val_dataloader = self.build_dataloader(
-            val_lX, val_lY, self.batch_size, split="val", shuffle=False
+            val_lX, val_lY, self.batch_size_eval, split="val", shuffle=False
         )
 
-        experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"  # TODO: add more params
+        experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"
         trainer = Trainer(
             model=self.model,
             optimizer_name="adamW",
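A side note on batch_size_eval: since evaluation does not need to hold optimizer state or (typically) gradients, a much larger batch usually fits in the same memory, so decoupling it from the training batch size (default 32 vs. 4) speeds up validation and transform without affecting training.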
@@ -163,8 +164,6 @@ class TransformerGen:
 
         self.fitted = True
 
-        # self.save_vgf(path="models/vgf/transformers/")
-
         return self
 
     def transform(self, lX):
|
||||||
l_embeds = defaultdict(list)
|
l_embeds = defaultdict(list)
|
||||||
|
|
||||||
dataloader = self.build_dataloader(
|
dataloader = self.build_dataloader(
|
||||||
lX, lY=None, batch_size=self.batch_size, split="whole", shuffle=False
|
lX, lY=None, batch_size=self.batch_size_eval, split="whole", shuffle=False
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
@@ -195,24 +194,23 @@ class TransformerGen:
     def fit_transform(self, lX, lY):
         return self.fit(lX, lY).transform(lX)
 
-    def save_vgf(self, path):
-        print(f"- saving Transformer View Generating Function to {path}")
-        return
+    def save_vgf(self, model_id):
+        import pickle
+        from os.path import join
+        from os import makedirs
+
+        vgf_name = "transformerGen"
+        _basedir = join("models", "vgfs", "transformer")
+        makedirs(_basedir, exist_ok=True)
+        _path = join(_basedir, f"{vgf_name}_{model_id}.pkl")
+        with open(_path, "wb") as f:
+            pickle.dump(self, f)
+        return self
+
+    def __str__(self):
+        str = f"[Transformer VGF (t)]\n- model_name: {self.model_name}\n- max_length: {self.max_length}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
+        return str
 
-    def get_config(self):
-        return {
-            "name": "Transformer VGF",
-            "model_name": self.model_name,
-            "max_length": self.max_length,
-            "batch_size": self.batch_size,
-            "lr": self.lr,
-            "epochs": self.epochs,
-            "device": self.device,
-            "print_steps": self.print_steps,
-            "evaluate_step": self.evaluate_step,
-            "patience": self.patience,
-            "probabilistic": self.probabilistic,
-        }
 
 class MultilingualDatasetTorch(Dataset):
     def __init__(self, lX, lY, split="train"):
@@ -285,7 +283,7 @@ class Trainer:
             - epochs: {epochs}
             - learning rate: {self.optimizer.defaults['lr']}
             - train batch size: {train_dataloader.batch_size}
-            - eval batch size: {'TODO'}
+            - eval batch size: {eval_dataloader.batch_size}
             - max len: {train_dataloader.dataset.X.shape[-1]}\n""",
         )
         for epoch in range(epochs):
@@ -55,25 +55,20 @@ class VanillaFunGen(ViewGen):
     def fit_transform(self, lX, lY):
         return self.fit(lX, lY).transform(lX)
 
-    def get_config(self):
-        return {
-            "name": "VanillaFunnelling VGF",
-            "base_learner": self.learners,
-            "first_tier_parameters": self.first_tier_parameters,
-            "n_jobs": self.n_jobs,
-        }
-
-    def save_vgf(self):
+    def save_vgf(self, model_id):
         import pickle
         from os.path import join
         from os import makedirs
 
-        model_id = "TODO"
-        vgf_name = "vanillaFunGen_todo"
-        _basedir = join("models", "vgfs", "posteriors")
+        vgf_name = "vanillaFunGen"
+        _basedir = join("models", "vgfs", "posterior")
         makedirs(_basedir, exist_ok=True)
-        _path = join(_basedir, f"{vgf_name}.pkl")
+        _path = join(_basedir, f"{vgf_name}_{model_id}.pkl")
         with open(_path, "wb") as f:
             pickle.dump(self, f)
         return self
+
+    def __str__(self):
+        _str = f"[VanillaFunGen (-p)]\n- base learner: {self.learners}\n- n_jobs: {self.n_jobs}\n"
+        # - parameters: {self.first_tier_parameters}
+        return _str
@@ -18,3 +18,7 @@ class ViewGen(ABC):
     @abstractmethod
     def fit_transform(self, lX, lY):
         pass
+
+    @abstractmethod
+    def save_vgf(self, model_id):
+        pass
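Making save_vgf abstract forces every VGF to implement it, and the four implementations in this commit are nearly identical, differing only in subdirectory and file prefix. A hypothetical refactor sketch (not part of this commit) that would hoist the shared body onto the base class:

    import os
    import pickle
    from abc import ABC, abstractmethod

    class ViewGen(ABC):
        # each subclass would declare these, e.g. "wordClassGen" / "wordclass"
        _vgf_name: str
        _vgf_subdir: str

        @abstractmethod
        def fit_transform(self, lX, lY):
            ...

        def save_vgf(self, model_id):
            # shared persistence logic: pickle self under models/vgfs/<subdir>/
            basedir = os.path.join("models", "vgfs", self._vgf_subdir)
            os.makedirs(basedir, exist_ok=True)
            path = os.path.join(basedir, f"{self._vgf_name}_{model_id}.pkl")
            with open(path, "wb") as f:
                pickle.dump(self, f)
            return self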
@@ -40,6 +40,23 @@ class WceGen(ViewGen):
             "sif": self.sif,
         }
 
+    def __str__(self):
+        _str = f"[WordClass VGF (w)]\n- sif: {self.sif}\n- n_jobs: {self.n_jobs}\n"
+        return _str
+
+    def save_vgf(self, model_id):
+        import pickle
+        from os.path import join
+        from os import makedirs
+
+        vgf_name = "wordClassGen"
+        _basedir = join("models", "vgfs", "wordclass")
+        makedirs(_basedir, exist_ok=True)
+        _path = join(_basedir, f"{vgf_name}_{model_id}.pkl")
+        with open(_path, "wb") as f:
+            pickle.dump(self, f)
+        return self
+
 
 def wce_matrix(X, Y):
     wce = supervised_embeddings_tfidf(X, Y)


main.py (11 changed lines)
@@ -11,11 +11,8 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
 
 """
 TODO:
-    - a cleaner way to save the model? each VGF saved independently (together with
-      standardizer and feature2posteriors). What about the metaclassifier and the vectorizers?
     - add documentations sphinx
     - zero-shot setup
 
 """
 
 
@@ -42,7 +39,7 @@ def main(args):
 
     tinit = time()
 
-    if not args.load_trained:
+    if args.load_trained is None:
         assert any(
             [
                 args.posteriors,
@@ -73,8 +70,12 @@ def main(args):
         load_trained=args.load_trained,
     )
 
+    # gfun.get_config()
     gfun.fit(lX, lY)
 
+    if args.load_trained is None:
+        gfun.save()
+
     # if not args.load_model:
     #     gfun.save()
 
@@ -95,7 +96,7 @@ def main(args):
 
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument("-l", "--load_trained", action="store_true")
+    parser.add_argument("-l", "--load_trained", type=str, default=None)
     # Dataset parameters -------------------
     parser.add_argument("--domains", type=str, default="all")
     parser.add_argument("--nrows", type=int, default=10000)
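End to end, --load_trained now takes a model id instead of acting as a bare flag. A hypothetical session (id and VGF flag names illustrative):

    # train, saving under the generated id, e.g. pmwt_230301
    python main.py --posteriors --multilingual --wce --transformer

    # later: reload that exact model instead of retraining
    python main.py -l pmwt_230301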