fixed loading function for Attention-based aggregating function when triggered by EarlyStopper
This commit is contained in:
parent
930a6d8275
commit
7ed98346a5
@@ -171,6 +171,9 @@ class MultilingualDataset:
        else:
            langs = sorted(self.multiling_dataset.keys())
        return langs

    def num_labels(self):
        return self.num_categories()

    def num_categories(self):
        return self.lYtr()[self.langs()[0]].shape[1]
@@ -22,6 +22,7 @@ class GeneralizedFunnelling:
        multilingual,
        transformer,
        langs,
        num_labels,
        embed_dir,
        n_jobs,
        batch_size,
@@ -37,6 +38,7 @@ class GeneralizedFunnelling:
        dataset_name,
        probabilistic,
        aggfunc,
        load_meta,
    ):
        # Setting VFGs -----------
        self.posteriors_vgf = posterior
@@ -44,7 +46,7 @@ class GeneralizedFunnelling:
        self.multilingual_vgf = multilingual
        self.trasformer_vgf = transformer
        self.probabilistic = probabilistic
        self.num_labels = 73  # TODO: hard-coded
        self.num_labels = num_labels
        # ------------------------
        self.langs = langs
        self.embed_dir = embed_dir
@@ -68,6 +70,10 @@ class GeneralizedFunnelling:
        self.metaclassifier = None
        self.aggfunc = aggfunc
        self.load_trained = load_trained
        self.load_first_tier = (
            True  # TODO: i guess we're always going to load at least the fitst tier
        )
        self.load_meta = load_meta
        self.dataset_name = dataset_name
        self._init()

@@ -77,11 +83,37 @@ class GeneralizedFunnelling:
            self.aggfunc == "mean" and self.probabilistic is False
        ), "When using averaging aggreagation function probabilistic must be True"
        if self.load_trained is not None:
            print("- loading trained VGFs, metaclassifer and vectorizer")
            self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load(
                self.load_trained
            # TODO: clean up this code here
            print(
                "- loading trained VGFs, metaclassifer and vectorizer"
                if self.load_meta
                else "- loading trained VGFs and vectorizer"
            )
            # TODO: config like aggfunc, device, n_jobs, etc
            self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load(
                self.load_trained,
                load_first_tier=self.load_first_tier,
                load_meta=self.load_meta,
            )
            if self.metaclassifier is None:
                self.metaclassifier = MetaClassifier(
                    meta_learner=get_learner(calibrate=True, kernel="rbf"),
                    meta_parameters=get_params(self.optimc),
                    n_jobs=self.n_jobs,
                )

            if "attn" in self.aggfunc:
                attn_stacking = self.aggfunc.split("_")[1]
                self.attn_aggregator = AttentionAggregator(
                    embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
                    out_dim=self.num_labels,
                    lr=self.lr_transformer,
                    patience=self.patience,
                    num_heads=1,
                    device=self.device,
                    epochs=self.epochs,
                    attn_stacking_type=attn_stacking,
                )
            return self

        if self.posteriors_vgf:
            fun = VanillaFunGen(
@@ -112,7 +144,7 @@ class GeneralizedFunnelling:
                epochs=self.epochs,
                batch_size=self.batch_size_transformer,
                max_length=self.max_length,
                device="cuda",
                device=self.device,
                print_steps=50,
                probabilistic=self.probabilistic,
                evaluate_step=self.evaluate_step,
@@ -121,13 +153,17 @@ class GeneralizedFunnelling:
            )
            self.first_tier_learners.append(transformer_vgf)

        if self.aggfunc == "attn":
        if "attn" in self.aggfunc:
            attn_stacking = self.aggfunc.split("_")[1]
            self.attn_aggregator = AttentionAggregator(
                embed_dim=self.get_attn_agg_dim(),
                out_dim=self.num_labels,
                lr=self.lr_transformer,
                patience=self.patience,
                num_heads=1,
                device=self.device,
                epochs=self.epochs,
                attn_stacking_type=attn_stacking,
            )

        self.metaclassifier = MetaClassifier(
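A minimal sketch of the parsing that the substring check above implies: the part after the underscore selects the stacking mode. The variant names "attn_concat" and "attn_mean" are assumptions inferred from the split("_") call and the stacking branches, not strings spelled out in this commit.

for aggfunc in ("attn_concat", "attn_mean"):  # hypothetical values
    if "attn" in aggfunc:
        attn_stacking = aggfunc.split("_")[1]  # -> "concat" or "mean"
        print(aggfunc, "->", attn_stacking)

Note that a bare "attn" value would make split("_")[1] raise an IndexError, so a stacking suffix appears to be required.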
@@ -141,6 +177,7 @@ class GeneralizedFunnelling:
            self.multilingual_vgf,
            self.wce_vgf,
            self.trasformer_vgf,
            self.aggfunc,
        )
        print(f"- model id: {self._model_id}")
        return self
@@ -153,11 +190,19 @@ class GeneralizedFunnelling:
    def fit(self, lX, lY):
        print("[Fitting GeneralizedFunnelling]")
        if self.load_trained is not None:
            print(f"- loaded trained model! Skipping training...")
            # TODO: add support to load only the first tier learners while re-training the metaclassifier
            load_only_first_tier = False
            if load_only_first_tier:
                raise NotImplementedError
            print(
                "- loaded first tier learners!"
                if self.load_meta is False
                else "- loaded trained model!"
            )
            if self.load_first_tier is True and self.load_meta is False:
                # TODO: clean up this code here
                projections = []
                for vgf in self.first_tier_learners:
                    l_posteriors = vgf.transform(lX)
                    projections.append(l_posteriors)
                agg = self.aggregate(projections, lY)
                self.metaclassifier.fit(agg, lY)
            return self

        self.vectorizer.fit(lX)
@@ -191,7 +236,8 @@ class GeneralizedFunnelling:
            aggregated = self._aggregate_mean(first_tier_projections)
        elif self.aggfunc == "concat":
            aggregated = self._aggregate_concat(first_tier_projections)
        elif self.aggfunc == "attn":
        # elif self.aggfunc == "attn":
        elif "attn" in self.aggfunc:
            aggregated = self._aggregate_attn(first_tier_projections, lY)
        else:
            raise NotImplementedError
@@ -238,27 +284,41 @@ class GeneralizedFunnelling:
            print(vgf)
            print("-" * 50)

    def save(self):
    def save(self, save_first_tier=True, save_meta=True):
        print(f"- Saving GeneralizedFunnelling model with id: {self._model_id}")
        # TODO: save only the first tier learners? what about some model config + sanity checks before loading?
        for vgf in self.first_tier_learners:
            vgf.save_vgf(model_id=self._model_id)
        os.makedirs(os.path.join("models", "metaclassifier"), exist_ok=True)
        with open(
            os.path.join("models", "metaclassifier", f"meta_{self._model_id}.pkl"), "wb"
        ) as f:
            pickle.dump(self.metaclassifier, f)

        os.makedirs(os.path.join("models", "vectorizer"), exist_ok=True)
        with open(
            os.path.join("models", "vectorizer", f"vectorizer_{self._model_id}.pkl"),
            "wb",
        ) as f:
            pickle.dump(self.vectorizer, f)

        if save_first_tier:
            self.save_first_tier_learners(model_id=self._model_id)

        if save_meta:
            with open(
                os.path.join("models", "metaclassifier", f"meta_{self._model_id}.pkl"),
                "wb",
            ) as f:
                pickle.dump(self.metaclassifier, f)
        return

    def load(self, model_id):
    def save_first_tier_learners(self, model_id):
        for vgf in self.first_tier_learners:
            vgf.save_vgf(model_id=self._model_id)
        return self

    def load(self, model_id, load_first_tier=True, load_meta=True):
        print(f"- loading model id: {model_id}")
        first_tier_learners = []

        with open(
            os.path.join("models", "vectorizer", f"vectorizer_{model_id}.pkl"), "rb"
        ) as f:
            vectorizer = pickle.load(f)

        if self.posteriors_vgf:
            with open(
                os.path.join(
@@ -291,20 +351,43 @@ class GeneralizedFunnelling:
                "rb",
            ) as vgf:
                first_tier_learners.append(pickle.load(vgf))
        with open(
            os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb"
        ) as f:
            metaclassifier = pickle.load(f)
        with open(
            os.path.join("models", "vectorizer", f"vectorizer_{model_id}.pkl"), "rb"
        ) as f:
            vectorizer = pickle.load(f)

        if load_meta:
            with open(
                os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb"
            ) as f:
                metaclassifier = pickle.load(f)
        else:
            metaclassifier = None
        return first_tier_learners, metaclassifier, vectorizer

    def get_attn_agg_dim(self):
        # TODO: hardcoded for now
        print("\n[NB: ATTN AGGREGATOR DIM HARD-CODED TO 146]\n")
        return 146
    def _load_meta(self):
        raise NotImplementedError

    def _load_posterior(self):
        raise NotImplementedError

    def _load_multilingual(self):
        raise NotImplementedError

    def _load_wce(self):
        raise NotImplementedError

    def _load_transformer(self):
        raise NotImplementedError

    def get_attn_agg_dim(self, attn_stacking_type="concat"):
        if self.probabilistic and "attn" not in self.aggfunc:
            return len(self.first_tier_learners) * self.num_labels
        elif self.probabilistic and "attn" in self.aggfunc:
            if attn_stacking_type == "concat":
                return len(self.first_tier_learners) * self.num_labels
            elif attn_stacking_type == "mean":
                return self.num_labels
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError


def get_params(optimc=False):
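The reworked get_attn_agg_dim above derives the aggregator input size from the number of first-tier learners, the number of labels, and the stacking mode, instead of the hard-coded 146. A small numeric illustration; the counts below are invented purely for the example:

num_first_tier_learners = 4  # hypothetical: e.g. posteriors + multilingual + wce + transformer
num_labels = 73              # hypothetical label count
concat_dim = num_first_tier_learners * num_labels  # "concat" stacking -> 292
mean_dim = num_labels                              # "mean" stacking   -> 73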
@@ -315,7 +398,7 @@ def get_params(optimc=False):
    return [{"kernel": [kernel], "C": c_range, "gamma": ["auto"]}]


def get_unique_id(posterior, multilingual, wce, transformer):
def get_unique_id(posterior, multilingual, wce, transformer, aggfunc):
    from datetime import datetime

    now = datetime.now().strftime("%y%m%d")
@@ -324,4 +407,5 @@ def get_unique_id(posterior, multilingual, wce, transformer):
    model_id += "m" if multilingual else ""
    model_id += "w" if wce else ""
    model_id += "t" if transformer else ""
    model_id += f"_{aggfunc}"
    return f"{model_id}_{now}"
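Taken together, the new switches let the first-tier learners and the meta-classifier be persisted and restored independently, and the aggregation function now becomes part of the model id. A hedged usage sketch; the model id below is invented, and the keyword arguments simply mirror the signatures introduced in this commit:

gfun.save(save_first_tier=True, save_meta=False)  # persist VGFs + vectorizer, skip the meta-classifier
first_tier, meta, vectorizer = gfun.load(
    "pmwt_attn_concat_230101", load_first_tier=True, load_meta=False
)
# meta comes back as None, so _init() builds a fresh MetaClassifier to be re-fitted.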
@@ -10,6 +10,7 @@ from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import normalize
from torch.optim import AdamW
from transformers.modeling_outputs import SequenceClassifierOutput
from sklearn.model_selection import train_test_split

from evaluation.evaluate import evaluate, log_eval

@@ -158,7 +159,6 @@ class Trainer:
                    self.device
                )
                break
        # TODO: maybe a lower lr?
        self.train_epoch(eval_dataloader, epoch=epoch)
        print(f"\n- last swipe on eval set")
        self.earlystopping.save_model(self.model)
@@ -176,7 +176,7 @@ class Trainer:
            loss.backward()
            self.optimizer.step()
            if (epoch + 1) % PRINT_ON_EPOCH == 0:
                if b_idx % self.print_steps == 0:
                if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
                    print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
        return self

@@ -209,7 +209,6 @@ class Trainer:


class EarlyStopping:
    # TODO: add checkpointing + restore model if early stopping + last swipe on validation set
    def __init__(
        self,
        patience,
@@ -247,8 +246,8 @@ class EarlyStopping:
        return True

    def save_model(self, model):
        os.makedirs(self.checkpoint_path, exist_ok=True)
        _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
        os.makedirs(_checkpoint_dir, exist_ok=True)
        model.save_pretrained(_checkpoint_dir)

    def load_model(self, model):
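save_model now writes the checkpoint into a per-experiment subdirectory instead of the bare checkpoint root. A minimal reconstruction of the resulting path, reusing the checkpoint_path and experiment_name values that AttentionAggregator passes to the Trainer later in this commit:

import os

checkpoint_path = "models/aggregator"    # as passed by AttentionAggregator.fit
experiment_name = "attention_aggregator"
print(os.path.join(checkpoint_path, experiment_name))  # models/aggregator/attention_aggregator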
@@ -257,51 +256,97 @@ class EarlyStopping:


class AttentionModule(nn.Module):
    def __init__(self, embed_dim, num_heads, out_dim):
    def __init__(self, embed_dim, num_heads, h_dim, out_dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.1)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.linear = nn.Linear(embed_dim, out_dim)

    def __call__(self, X):
        attn_out, attn_weights = self.attn(query=X, key=X, value=X)
        out = self.linear(attn_out)
        out, attn_weights = self.attn(query=X, key=X, value=X)
        out = self.layer_norm(out)
        out = self.linear(out)
        # out = self.sigmoid(out)
        return out
        # out = self.relu(out)
        # out = self.linear2(out)
        # out = self.sigmoid(out)

    def transform(self, X):
        attn_out, attn_weights = self.attn(query=X, key=X, value=X)
        return attn_out
        return self.__call__(X)
        # out, attn_weights = self.attn(query=X, key=X, value=X)
        # out = self.layer_norm(out)
        # out = self.linear(out)
        # out = self.sigmoid(out)
        # return out
        # out = self.relu(out)
        # out = self.linear2(out)
        # out = self.sigmoid(out)

    def save_pretrained(self, path):
        torch.save(self.state_dict(), f"{path}.pt")
        torch.save(self, f"{path}.pt")
        # torch.save(self.state_dict(), f"{path}.pt")

    def _wtf(self):
        print("wtf")
    def from_pretrained(self, path):
        return torch.load(f"{path}.pt")


class AttentionAggregator:
    def __init__(self, embed_dim, out_dim, epochs, num_heads=1, device="cpu"):
    def __init__(
        self,
        embed_dim,
        out_dim,
        epochs,
        lr,
        patience,
        attn_stacking_type,
        h_dim=512,
        num_heads=1,
        device="cpu",
    ):
        self.embed_dim = embed_dim
        self.h_dim = h_dim
        self.out_dim = out_dim
        self.patience = patience
        self.num_heads = num_heads
        self.device = device
        self.epochs = epochs
        self.attn = AttentionModule(embed_dim, num_heads, out_dim).to(self.device)
        self.lr = lr
        self.stacking_type = attn_stacking_type
        self.tr_batch_size = 512
        self.eval_batch_size = 1024
        self.attn = AttentionModule(
            self.embed_dim, self.num_heads, self.h_dim, self.out_dim
        ).to(self.device)

    def fit(self, X, Y):
        print("- fitting Attention-based aggregating function")
        hstacked_X = self.stack(X)

        dataset = AggregatorDatasetTorch(hstacked_X, Y)
        tra_dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
            hstacked_X, Y, split=0.2, seed=42
        )

        tra_dataloader = DataLoader(
            AggregatorDatasetTorch(tr_lX, tr_lY, split="train"),
            batch_size=self.tr_batch_size,
            shuffle=True,
        )
        eval_dataloader = DataLoader(
            AggregatorDatasetTorch(val_lX, val_lY, split="eval"),
            batch_size=self.eval_batch_size,
            shuffle=False,
        )

        experiment_name = "attention_aggregator"
        trainer = Trainer(
            self.attn,
            optimizer_name="adamW",
            lr=1e-3,
            lr=self.lr,
            loss_fn=torch.nn.CrossEntropyLoss(),
            print_steps=100,
            evaluate_step=1000,
            patience=10,
            print_steps=25,
            evaluate_step=50,
            patience=self.patience,
            experiment_name=experiment_name,
            device=self.device,
            checkpoint_path="models/aggregator",
@@ -309,15 +354,14 @@ class AttentionAggregator:

        trainer.train(
            train_dataloader=tra_dataloader,
            eval_dataloader=tra_dataloader,
            eval_dataloader=eval_dataloader,
            epochs=self.epochs,
        )
        return self

    def transform(self, X):
        # TODO: implement transform
        h_stacked = self.stack(X)
        dataset = AggregatorDatasetTorch(h_stacked, lY=None, split="whole")
        hstacked_X = self.stack(X)
        dataset = AggregatorDatasetTorch(hstacked_X, lY=None, split="whole")
        dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

        _embeds = []

@@ -339,10 +383,13 @@ class AttentionAggregator:
        return l_embeds

    def stack(self, data):
        hstack = self._hstack(data)
        if self.stacking_type == "concat":
            hstack = self._concat_stack(data)
        elif self.stacking_type == "mean":
            hstack = self._mean_stack(data)
        return hstack

    def _hstack(self, data):
    def _concat_stack(self, data):
        _langs = data[0].keys()
        l_projections = {}
        for l in _langs:
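For intuition, a small shape sketch of the two stacking modes, assuming three first-tier views, 10 documents and 73 labels for a single language; all numbers are invented, and the horizontal concatenation and averaging only mirror what _concat_stack and _mean_stack do per language:

import torch

data = [{"en": torch.rand(10, 73)} for _ in range(3)]  # hypothetical per-language projections
concat = torch.hstack([d["en"] for d in data])          # concat stacking per language
mean = sum(d["en"] for d in data) / len(data)           # mean stacking per language
print(concat.shape, mean.shape)                         # torch.Size([10, 219]) torch.Size([10, 73])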
@@ -351,8 +398,31 @@ class AttentionAggregator:
            )
        return l_projections

    def _vstack(self, data):
        return torch.vstack()
    def _mean_stack(self, data):
        # TODO: double check this mess
        aggregated = {lang: torch.zeros(d.shape) for lang, d in data[0].items()}
        for lang_projections in data:
            for lang, projection in lang_projections.items():
                aggregated[lang] += projection

        for lang, projection in aggregated.items():
            aggregated[lang] = (aggregated[lang] / len(data)).float()

        return aggregated

    def get_train_val_data(self, lX, lY, split=0.2, seed=42):
        tr_lX, tr_lY, val_lX, val_lY = {}, {}, {}, {}

        for lang in lX.keys():
            tr_X, val_X, tr_Y, val_Y = train_test_split(
                lX[lang], lY[lang], test_size=split, random_state=seed, shuffle=False
            )
            tr_lX[lang] = tr_X
            tr_lY[lang] = tr_Y
            val_lX[lang] = val_X
            val_lY[lang] = val_Y

        return tr_lX, tr_lY, val_lX, val_lY


class AggregatorDatasetTorch(Dataset):
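The save_pretrained/from_pretrained pair above appears to be the core of the fix named in the commit title: when EarlyStopping triggers, save_model serializes the whole AttentionModule with torch.save(self, ...), and from_pretrained simply torch.loads it back, so the attention aggregator can be restored without rebuilding it from a state_dict. A minimal round-trip sketch; the dimensions are illustrative only:

import torch

module = AttentionModule(embed_dim=219, num_heads=1, h_dim=512, out_dim=73)
module.save_pretrained("models/aggregator/attention_aggregator")             # writes a .pt file via torch.save(self, ...)
restored = module.from_pretrained("models/aggregator/attention_aggregator")  # torch.load of the whole module
assert isinstance(restored, AttentionModule)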
@@ -16,7 +16,6 @@ transformers.logging.set_verbosity_error()


class VisualTransformerGen(ViewGen, TransformerGen):
    # TODO: probabilistic behaviour
    def __init__(
        self,
        model_name,
41  main.py
@@ -13,18 +13,23 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
TODO:
    - add documentations sphinx
    - zero-shot setup
    - set probabilistic behaviour in Transformer parent-class
    - pooling / attention aggregation
    - load pre-trained VGFs while retaining ability to train new ones (self.fitted = True in loaded? or smt like that)
    - test split in MultiNews dataset
    - when we load a model and change its config (eg change the agg func, re-train meta), we should store this model as a new one (save it)
"""


def get_dataset(datasetname):
    assert datasetname in ["multinews", "amazon", "rcv1-2"], "dataset not supported"

    RCV_DATAPATH = expanduser(
        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
    )
    JRC_DATAPATH = expanduser(
        "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
    )
    MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")

    if datasetname == "multinews":
        dataset = MultiNewsDataset(
            expanduser(MULTINEWS_DATAPATH),
@@ -38,11 +43,9 @@ def get_dataset(datasetname):
            max_labels=args.max_labels,
        )
    elif datasetname == "rcv1-2":
        dataset = (
            MultilingualDataset(dataset_name="rcv1-2")
            .load(RCV_DATAPATH)
            .reduce_data(langs=["en", "it", "fr"], maxn=args.nrows)
        )
        dataset = MultilingualDataset(dataset_name="rcv1-2").load(RCV_DATAPATH)
        if args.nrows is not None:
            dataset.reduce_data(langs=["en", "it", "fr"], maxn=args.nrows)
    else:
        raise NotImplementedError
    return dataset
@@ -55,11 +58,9 @@ def main(args):
    ):
        lX, lY = dataset.training()
        lX_te, lY_te = dataset.test()
        # print("[NB: for debug purposes, training set is also used as test set]\n")
        # lX_te, lY_te = dataset.training()
    else:
        _lX = dataset.dX
        _lY = dataset.dY
        lX = dataset.dX
        lY = dataset.dY

    tinit = time()

@@ -78,6 +79,7 @@ def main(args):
        # dataset params ----------------------
        dataset_name=args.dataset,
        langs=dataset.langs(),
        num_labels=dataset.num_labels(),
        # Posterior VGF params ----------------
        posterior=args.posteriors,
        # Multilingual VGF params -------------
@@ -100,22 +102,20 @@ def main(args):
        aggfunc=args.aggfunc,
        optimc=args.optimc,
        load_trained=args.load_trained,
        load_meta=args.meta,
        n_jobs=args.n_jobs,
    )

    # gfun.get_config()
    gfun.fit(lX, lY)

    if args.load_trained is not None:
        gfun.save()

    # if not args.load_model:
    #     gfun.save()
    if args.load_trained is None:
        gfun.save(save_first_tier=True, save_meta=True)

    preds = gfun.transform(lX)

    train_eval = evaluate(lY, preds)
    log_eval(train_eval, phase="train")
    # train_eval = evaluate(lY, preds)
    # log_eval(train_eval, phase="train")

    timetr = time()
    print(f"- training completed in {timetr - tinit:.2f} seconds")
@@ -130,10 +130,11 @@ def main(args):
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("-l", "--load_trained", type=str, default=None)
    parser.add_argument("--meta", action="store_true")
    # Dataset parameters -------------------
    parser.add_argument("-d", "--dataset", type=str, default="multinews")
    parser.add_argument("--domains", type=str, default="all")
    parser.add_argument("--nrows", type=int, default=100)
    parser.add_argument("--nrows", type=int, default=None)
    parser.add_argument("--min_count", type=int, default=10)
    parser.add_argument("--max_labels", type=int, default=50)
    # gFUN parameters ----------------------
@@ -148,7 +149,7 @@ if __name__ == "__main__":
    # transformer parameters ---------------
    parser.add_argument("--transformer_name", type=str, default="mbert")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--epochs", type=int, default=1000)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--patience", type=int, default=5)