Fixed loading function for the Attention-based aggregating function when triggered by EarlyStopping
commit 7ed98346a5 (parent 930a6d8275)
@@ -172,6 +172,9 @@ class MultilingualDataset:
         langs = sorted(self.multiling_dataset.keys())
         return langs

+    def num_labels(self):
+        return self.num_categories()
+
     def num_categories(self):
         return self.lYtr()[self.langs()[0]].shape[1]


@@ -22,6 +22,7 @@ class GeneralizedFunnelling:
         multilingual,
         transformer,
         langs,
+        num_labels,
         embed_dir,
         n_jobs,
         batch_size,
@@ -37,6 +38,7 @@ class GeneralizedFunnelling:
         dataset_name,
         probabilistic,
         aggfunc,
+        load_meta,
     ):
         # Setting VFGs -----------
         self.posteriors_vgf = posterior
@@ -44,7 +46,7 @@ class GeneralizedFunnelling:
         self.multilingual_vgf = multilingual
         self.trasformer_vgf = transformer
         self.probabilistic = probabilistic
-        self.num_labels = 73  # TODO: hard-coded
+        self.num_labels = num_labels
         # ------------------------
         self.langs = langs
         self.embed_dir = embed_dir
@@ -68,6 +70,10 @@ class GeneralizedFunnelling:
         self.metaclassifier = None
         self.aggfunc = aggfunc
         self.load_trained = load_trained
+        self.load_first_tier = (
+            True  # TODO: i guess we're always going to load at least the fitst tier
+        )
+        self.load_meta = load_meta
         self.dataset_name = dataset_name
         self._init()

@@ -77,11 +83,37 @@ class GeneralizedFunnelling:
             self.aggfunc == "mean" and self.probabilistic is False
         ), "When using averaging aggreagation function probabilistic must be True"
         if self.load_trained is not None:
-            print("- loading trained VGFs, metaclassifer and vectorizer")
-            self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load(
-                self.load_trained
+            # TODO: clean up this code here
+            print(
+                "- loading trained VGFs, metaclassifer and vectorizer"
+                if self.load_meta
+                else "- loading trained VGFs and vectorizer"
             )
-            # TODO: config like aggfunc, device, n_jobs, etc
+            self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load(
+                self.load_trained,
+                load_first_tier=self.load_first_tier,
+                load_meta=self.load_meta,
+            )
+            if self.metaclassifier is None:
+                self.metaclassifier = MetaClassifier(
+                    meta_learner=get_learner(calibrate=True, kernel="rbf"),
+                    meta_parameters=get_params(self.optimc),
+                    n_jobs=self.n_jobs,
+                )
+
+            if "attn" in self.aggfunc:
+                attn_stacking = self.aggfunc.split("_")[1]
+                self.attn_aggregator = AttentionAggregator(
+                    embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
+                    out_dim=self.num_labels,
+                    lr=self.lr_transformer,
+                    patience=self.patience,
+                    num_heads=1,
+                    device=self.device,
+                    epochs=self.epochs,
+                    attn_stacking_type=attn_stacking,
+                )
+            return self

         if self.posteriors_vgf:
             fun = VanillaFunGen(
@@ -112,7 +144,7 @@ class GeneralizedFunnelling:
                 epochs=self.epochs,
                 batch_size=self.batch_size_transformer,
                 max_length=self.max_length,
-                device="cuda",
+                device=self.device,
                 print_steps=50,
                 probabilistic=self.probabilistic,
                 evaluate_step=self.evaluate_step,
@@ -121,13 +153,17 @@ class GeneralizedFunnelling:
             )
             self.first_tier_learners.append(transformer_vgf)

-        if self.aggfunc == "attn":
+        if "attn" in self.aggfunc:
+            attn_stacking = self.aggfunc.split("_")[1]
             self.attn_aggregator = AttentionAggregator(
                 embed_dim=self.get_attn_agg_dim(),
                 out_dim=self.num_labels,
+                lr=self.lr_transformer,
+                patience=self.patience,
                 num_heads=1,
                 device=self.device,
                 epochs=self.epochs,
+                attn_stacking_type=attn_stacking,
             )

         self.metaclassifier = MetaClassifier(
@@ -141,6 +177,7 @@ class GeneralizedFunnelling:
             self.multilingual_vgf,
             self.wce_vgf,
             self.trasformer_vgf,
+            self.aggfunc,
         )
         print(f"- model id: {self._model_id}")
         return self
@@ -153,11 +190,19 @@ class GeneralizedFunnelling:
     def fit(self, lX, lY):
         print("[Fitting GeneralizedFunnelling]")
         if self.load_trained is not None:
-            print(f"- loaded trained model! Skipping training...")
-            # TODO: add support to load only the first tier learners while re-training the metaclassifier
-            load_only_first_tier = False
-            if load_only_first_tier:
-                raise NotImplementedError
+            print(
+                "- loaded first tier learners!"
+                if self.load_meta is False
+                else "- loaded trained model!"
+            )
+            if self.load_first_tier is True and self.load_meta is False:
+                # TODO: clean up this code here
+                projections = []
+                for vgf in self.first_tier_learners:
+                    l_posteriors = vgf.transform(lX)
+                    projections.append(l_posteriors)
+                agg = self.aggregate(projections, lY)
+                self.metaclassifier.fit(agg, lY)
             return self

         self.vectorizer.fit(lX)
@@ -191,7 +236,8 @@ class GeneralizedFunnelling:
             aggregated = self._aggregate_mean(first_tier_projections)
         elif self.aggfunc == "concat":
            aggregated = self._aggregate_concat(first_tier_projections)
-        elif self.aggfunc == "attn":
+        # elif self.aggfunc == "attn":
+        elif "attn" in self.aggfunc:
             aggregated = self._aggregate_attn(first_tier_projections, lY)
         else:
             raise NotImplementedError
@@ -238,27 +284,41 @@ class GeneralizedFunnelling:
             print(vgf)
             print("-" * 50)

-    def save(self):
+    def save(self, save_first_tier=True, save_meta=True):
         print(f"- Saving GeneralizedFunnelling model with id: {self._model_id}")
-        # TODO: save only the first tier learners? what about some model config + sanity checks before loading?
-        for vgf in self.first_tier_learners:
-            vgf.save_vgf(model_id=self._model_id)
-        os.makedirs(os.path.join("models", "metaclassifier"), exist_ok=True)
-        with open(
-            os.path.join("models", "metaclassifier", f"meta_{self._model_id}.pkl"), "wb"
-        ) as f:
-            pickle.dump(self.metaclassifier, f)
         os.makedirs(os.path.join("models", "vectorizer"), exist_ok=True)
         with open(
             os.path.join("models", "vectorizer", f"vectorizer_{self._model_id}.pkl"),
             "wb",
         ) as f:
             pickle.dump(self.vectorizer, f)

+        if save_first_tier:
+            self.save_first_tier_learners(model_id=self._model_id)
+
+        if save_meta:
+            with open(
+                os.path.join("models", "metaclassifier", f"meta_{self._model_id}.pkl"),
+                "wb",
+            ) as f:
+                pickle.dump(self.metaclassifier, f)
         return

-    def load(self, model_id):
+    def save_first_tier_learners(self, model_id):
+        for vgf in self.first_tier_learners:
+            vgf.save_vgf(model_id=self._model_id)
+        return self
+
+    def load(self, model_id, load_first_tier=True, load_meta=True):
         print(f"- loading model id: {model_id}")
         first_tier_learners = []
+
+        with open(
+            os.path.join("models", "vectorizer", f"vectorizer_{model_id}.pkl"), "rb"
+        ) as f:
+            vectorizer = pickle.load(f)
+
         if self.posteriors_vgf:
             with open(
                 os.path.join(
@@ -291,20 +351,43 @@ class GeneralizedFunnelling:
                 "rb",
             ) as vgf:
                 first_tier_learners.append(pickle.load(vgf))

-        with open(
-            os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb"
-        ) as f:
-            metaclassifier = pickle.load(f)
-        with open(
-            os.path.join("models", "vectorizer", f"vectorizer_{model_id}.pkl"), "rb"
-        ) as f:
-            vectorizer = pickle.load(f)
+        if load_meta:
+            with open(
+                os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb"
+            ) as f:
+                metaclassifier = pickle.load(f)
+        else:
+            metaclassifier = None
         return first_tier_learners, metaclassifier, vectorizer

-    def get_attn_agg_dim(self):
-        # TODO: hardcoded for now
-        print("\n[NB: ATTN AGGREGATOR DIM HARD-CODED TO 146]\n")
-        return 146
+    def _load_meta(self):
+        raise NotImplementedError
+
+    def _load_posterior(self):
+        raise NotImplementedError
+
+    def _load_multilingual(self):
+        raise NotImplementedError
+
+    def _load_wce(self):
+        raise NotImplementedError
+
+    def _load_transformer(self):
+        raise NotImplementedError
+
+    def get_attn_agg_dim(self, attn_stacking_type="concat"):
+        if self.probabilistic and "attn" not in self.aggfunc:
+            return len(self.first_tier_learners) * self.num_labels
+        elif self.probabilistic and "attn" in self.aggfunc:
+            if attn_stacking_type == "concat":
+                return len(self.first_tier_learners) * self.num_labels
+            elif attn_stacking_type == "mean":
+                return self.num_labels
+            else:
+                raise NotImplementedError
+        else:
+            raise NotImplementedError


 def get_params(optimc=False):
@@ -315,7 +398,7 @@ def get_params(optimc=False):
     return [{"kernel": [kernel], "C": c_range, "gamma": ["auto"]}]


-def get_unique_id(posterior, multilingual, wce, transformer):
+def get_unique_id(posterior, multilingual, wce, transformer, aggfunc):
     from datetime import datetime

     now = datetime.now().strftime("%y%m%d")
@@ -324,4 +407,5 @@ def get_unique_id(posterior, multilingual, wce, transformer):
     model_id += "m" if multilingual else ""
     model_id += "w" if wce else ""
     model_id += "t" if transformer else ""
+    model_id += f"_{aggfunc}"
     return f"{model_id}_{now}"
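Note on the get_attn_agg_dim() logic added above: the aggregator's input width is now derived from the number of first-tier learners and the label count instead of being hard-coded to 146. A minimal sketch of the same arithmetic, assuming three fitted VGFs on a 73-label dataset (example values, not repository defaults; attn_agg_dim below is an illustrative helper, not repo code):

# Illustrative helper mirroring the dimension arithmetic of the new
# GeneralizedFunnelling.get_attn_agg_dim(); not code from the repository.
def attn_agg_dim(n_first_tier_learners, num_labels, stacking="concat"):
    if stacking == "concat":
        # horizontally concatenated posteriors of every first-tier VGF
        return n_first_tier_learners * num_labels
    elif stacking == "mean":
        # averaged posteriors keep the label-space dimensionality
        return num_labels
    raise NotImplementedError(stacking)


# e.g. three fitted VGFs (posterior, multilingual, transformer) on a 73-label dataset:
assert attn_agg_dim(3, 73, stacking="concat") == 219
assert attn_agg_dim(3, 73, stacking="mean") == 73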
@@ -10,6 +10,7 @@ from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.preprocessing import normalize
 from torch.optim import AdamW
 from transformers.modeling_outputs import SequenceClassifierOutput
+from sklearn.model_selection import train_test_split

 from evaluation.evaluate import evaluate, log_eval

@@ -158,7 +159,6 @@ class Trainer:
                     self.device
                 )
                 break
-        # TODO: maybe a lower lr?
         self.train_epoch(eval_dataloader, epoch=epoch)
         print(f"\n- last swipe on eval set")
         self.earlystopping.save_model(self.model)
@@ -176,7 +176,7 @@ class Trainer:
             loss.backward()
             self.optimizer.step()
             if (epoch + 1) % PRINT_ON_EPOCH == 0:
-                if b_idx % self.print_steps == 0:
+                if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
                     print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
         return self

@@ -209,7 +209,6 @@ class Trainer:


 class EarlyStopping:
-    # TODO: add checkpointing + restore model if early stopping + last swipe on validation set
     def __init__(
         self,
         patience,
@@ -247,8 +246,8 @@ class EarlyStopping:
         return True

     def save_model(self, model):
+        os.makedirs(self.checkpoint_path, exist_ok=True)
         _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
-        os.makedirs(_checkpoint_dir, exist_ok=True)
         model.save_pretrained(_checkpoint_dir)

     def load_model(self, model):
@@ -257,51 +256,97 @@ class EarlyStopping:


 class AttentionModule(nn.Module):
-    def __init__(self, embed_dim, num_heads, out_dim):
+    def __init__(self, embed_dim, num_heads, h_dim, out_dim):
         super().__init__()
-        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
+        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.1)
+        self.layer_norm = nn.LayerNorm(embed_dim)
         self.linear = nn.Linear(embed_dim, out_dim)

     def __call__(self, X):
-        attn_out, attn_weights = self.attn(query=X, key=X, value=X)
-        out = self.linear(attn_out)
+        out, attn_weights = self.attn(query=X, key=X, value=X)
+        out = self.layer_norm(out)
+        out = self.linear(out)
+        # out = self.sigmoid(out)
         return out
+        # out = self.relu(out)
+        # out = self.linear2(out)
+        # out = self.sigmoid(out)

     def transform(self, X):
-        attn_out, attn_weights = self.attn(query=X, key=X, value=X)
-        return attn_out
+        return self.__call__(X)
+        # out, attn_weights = self.attn(query=X, key=X, value=X)
+        # out = self.layer_norm(out)
+        # out = self.linear(out)
+        # out = self.sigmoid(out)
+        # return out
+        # out = self.relu(out)
+        # out = self.linear2(out)
+        # out = self.sigmoid(out)

     def save_pretrained(self, path):
-        torch.save(self.state_dict(), f"{path}.pt")
+        torch.save(self, f"{path}.pt")
+        # torch.save(self.state_dict(), f"{path}.pt")

-    def _wtf(self):
-        print("wtf")
+    def from_pretrained(self, path):
+        return torch.load(f"{path}.pt")


 class AttentionAggregator:
-    def __init__(self, embed_dim, out_dim, epochs, num_heads=1, device="cpu"):
+    def __init__(
+        self,
+        embed_dim,
+        out_dim,
+        epochs,
+        lr,
+        patience,
+        attn_stacking_type,
+        h_dim=512,
+        num_heads=1,
+        device="cpu",
+    ):
         self.embed_dim = embed_dim
+        self.h_dim = h_dim
+        self.out_dim = out_dim
+        self.patience = patience
         self.num_heads = num_heads
         self.device = device
         self.epochs = epochs
-        self.attn = AttentionModule(embed_dim, num_heads, out_dim).to(self.device)
+        self.lr = lr
+        self.stacking_type = attn_stacking_type
+        self.tr_batch_size = 512
+        self.eval_batch_size = 1024
+        self.attn = AttentionModule(
+            self.embed_dim, self.num_heads, self.h_dim, self.out_dim
+        ).to(self.device)

     def fit(self, X, Y):
         print("- fitting Attention-based aggregating function")
         hstacked_X = self.stack(X)

-        dataset = AggregatorDatasetTorch(hstacked_X, Y)
-        tra_dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
+        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
+            hstacked_X, Y, split=0.2, seed=42
+        )
+
+        tra_dataloader = DataLoader(
+            AggregatorDatasetTorch(tr_lX, tr_lY, split="train"),
+            batch_size=self.tr_batch_size,
+            shuffle=True,
+        )
+        eval_dataloader = DataLoader(
+            AggregatorDatasetTorch(val_lX, val_lY, split="eval"),
+            batch_size=self.eval_batch_size,
+            shuffle=False,
+        )

         experiment_name = "attention_aggregator"
         trainer = Trainer(
             self.attn,
             optimizer_name="adamW",
-            lr=1e-3,
+            lr=self.lr,
             loss_fn=torch.nn.CrossEntropyLoss(),
-            print_steps=100,
-            evaluate_step=1000,
-            patience=10,
+            print_steps=25,
+            evaluate_step=50,
+            patience=self.patience,
             experiment_name=experiment_name,
             device=self.device,
             checkpoint_path="models/aggregator",
@@ -309,15 +354,14 @@ class AttentionAggregator:

         trainer.train(
             train_dataloader=tra_dataloader,
-            eval_dataloader=tra_dataloader,
+            eval_dataloader=eval_dataloader,
             epochs=self.epochs,
         )
         return self

     def transform(self, X):
-        # TODO: implement transform
-        h_stacked = self.stack(X)
-        dataset = AggregatorDatasetTorch(h_stacked, lY=None, split="whole")
+        hstacked_X = self.stack(X)
+        dataset = AggregatorDatasetTorch(hstacked_X, lY=None, split="whole")
         dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

         _embeds = []
@@ -339,10 +383,13 @@ class AttentionAggregator:
         return l_embeds

     def stack(self, data):
-        hstack = self._hstack(data)
+        if self.stacking_type == "concat":
+            hstack = self._concat_stack(data)
+        elif self.stacking_type == "mean":
+            hstack = self._mean_stack(data)
         return hstack

-    def _hstack(self, data):
+    def _concat_stack(self, data):
         _langs = data[0].keys()
         l_projections = {}
         for l in _langs:
@@ -351,8 +398,31 @@ class AttentionAggregator:
             )
         return l_projections

-    def _vstack(self, data):
-        return torch.vstack()
+    def _mean_stack(self, data):
+        # TODO: double check this mess
+        aggregated = {lang: torch.zeros(d.shape) for lang, d in data[0].items()}
+        for lang_projections in data:
+            for lang, projection in lang_projections.items():
+                aggregated[lang] += projection
+
+        for lang, projection in aggregated.items():
+            aggregated[lang] = (aggregated[lang] / len(data)).float()
+
+        return aggregated
+
+    def get_train_val_data(self, lX, lY, split=0.2, seed=42):
+        tr_lX, tr_lY, val_lX, val_lY = {}, {}, {}, {}
+
+        for lang in lX.keys():
+            tr_X, val_X, tr_Y, val_Y = train_test_split(
+                lX[lang], lY[lang], test_size=split, random_state=seed, shuffle=False
+            )
+            tr_lX[lang] = tr_X
+            tr_lY[lang] = tr_Y
+            val_lX[lang] = val_X
+            val_lY[lang] = val_Y
+
+        return tr_lX, tr_lY, val_lX, val_lY


 class AggregatorDatasetTorch(Dataset):
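Note on the new "mean" stacking path used by AttentionAggregator.stack() above: the per-VGF posterior dictionaries (one tensor per language) are summed element-wise and divided by the number of VGFs. A small, self-contained sketch of that behaviour with toy shapes (mean_stack below is an illustrative stand-in for _mean_stack, and the 73-label width is just an example):

import torch


def mean_stack(projections):
    # projections: list of {lang: tensor(num_docs, num_labels)}, one dict per VGF
    aggregated = {lang: torch.zeros(p.shape) for lang, p in projections[0].items()}
    for vgf_projections in projections:
        for lang, projection in vgf_projections.items():
            aggregated[lang] += projection
    # average over the number of first-tier views
    return {lang: (summed / len(projections)).float() for lang, summed in aggregated.items()}


posteriors_vgf = {"en": torch.ones(4, 73), "it": torch.zeros(2, 73)}
transformer_vgf = {"en": torch.zeros(4, 73), "it": torch.ones(2, 73)}
averaged = mean_stack([posteriors_vgf, transformer_vgf])
assert torch.allclose(averaged["en"], torch.full((4, 73), 0.5))
assert torch.allclose(averaged["it"], torch.full((2, 73), 0.5))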
@@ -16,7 +16,6 @@ transformers.logging.set_verbosity_error()


 class VisualTransformerGen(ViewGen, TransformerGen):
-    # TODO: probabilistic behaviour
     def __init__(
         self,
         model_name,


main.py (41 changed lines)
@@ -13,18 +13,23 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
 TODO:
     - add documentations sphinx
     - zero-shot setup
-    - set probabilistic behaviour in Transformer parent-class
-    - pooling / attention aggregation
+    - load pre-trained VGFs while retaining ability to train new ones (self.fitted = True in loaded? or smt like that)
     - test split in MultiNews dataset
+    - when we load a model and change its config (eg change the agg func, re-train meta), we should store this model as a new one (save it)
 """


 def get_dataset(datasetname):
     assert datasetname in ["multinews", "amazon", "rcv1-2"], "dataset not supported"

     RCV_DATAPATH = expanduser(
         "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
     )
+    JRC_DATAPATH = expanduser(
+        "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
+    )
     MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")

     if datasetname == "multinews":
         dataset = MultiNewsDataset(
             expanduser(MULTINEWS_DATAPATH),
@@ -38,11 +43,9 @@ def get_dataset(datasetname):
             max_labels=args.max_labels,
         )
     elif datasetname == "rcv1-2":
-        dataset = (
-            MultilingualDataset(dataset_name="rcv1-2")
-            .load(RCV_DATAPATH)
-            .reduce_data(langs=["en", "it", "fr"], maxn=args.nrows)
-        )
+        dataset = MultilingualDataset(dataset_name="rcv1-2").load(RCV_DATAPATH)
+        if args.nrows is not None:
+            dataset.reduce_data(langs=["en", "it", "fr"], maxn=args.nrows)
     else:
         raise NotImplementedError
     return dataset
@@ -55,11 +58,9 @@ def main(args):
     ):
         lX, lY = dataset.training()
         lX_te, lY_te = dataset.test()
-        # print("[NB: for debug purposes, training set is also used as test set]\n")
-        # lX_te, lY_te = dataset.training()
     else:
-        _lX = dataset.dX
-        _lY = dataset.dY
+        lX = dataset.dX
+        lY = dataset.dY

     tinit = time()

@@ -78,6 +79,7 @@ def main(args):
         # dataset params ----------------------
         dataset_name=args.dataset,
         langs=dataset.langs(),
+        num_labels=dataset.num_labels(),
         # Posterior VGF params ----------------
         posterior=args.posteriors,
         # Multilingual VGF params -------------
@@ -100,22 +102,20 @@ def main(args):
         aggfunc=args.aggfunc,
         optimc=args.optimc,
         load_trained=args.load_trained,
+        load_meta=args.meta,
         n_jobs=args.n_jobs,
     )

     # gfun.get_config()
     gfun.fit(lX, lY)

-    if args.load_trained is not None:
-        gfun.save()
-
-    # if not args.load_model:
-    # gfun.save()
+    if args.load_trained is None:
+        gfun.save(save_first_tier=True, save_meta=True)

     preds = gfun.transform(lX)

-    train_eval = evaluate(lY, preds)
-    log_eval(train_eval, phase="train")
+    # train_eval = evaluate(lY, preds)
+    # log_eval(train_eval, phase="train")

     timetr = time()
     print(f"- training completed in {timetr - tinit:.2f} seconds")
@@ -130,10 +130,11 @@ def main(args):
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("-l", "--load_trained", type=str, default=None)
+    parser.add_argument("--meta", action="store_true")
     # Dataset parameters -------------------
     parser.add_argument("-d", "--dataset", type=str, default="multinews")
     parser.add_argument("--domains", type=str, default="all")
-    parser.add_argument("--nrows", type=int, default=100)
+    parser.add_argument("--nrows", type=int, default=None)
     parser.add_argument("--min_count", type=int, default=10)
    parser.add_argument("--max_labels", type=int, default=50)
     # gFUN parameters ----------------------
@@ -148,7 +149,7 @@ if __name__ == "__main__":
     # transformer parameters ---------------
     parser.add_argument("--transformer_name", type=str, default="mbert")
     parser.add_argument("--batch_size", type=int, default=32)
-    parser.add_argument("--epochs", type=int, default=10)
+    parser.add_argument("--epochs", type=int, default=1000)
     parser.add_argument("--lr", type=float, default=1e-5)
     parser.add_argument("--max_length", type=int, default=512)
     parser.add_argument("--patience", type=int, default=5)
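For context on the commit title: EarlyStopping.save_model() checkpoints the aggregator through AttentionModule.save_pretrained(), which now pickles the whole module, and from_pretrained() simply torch.load()s it back. A minimal sketch of that round trip under the same convention; TinyAttnModule, the toy dimensions, and the explicit weights_only=False (needed on recent PyTorch to unpickle full modules) are assumptions for illustration, not repository code:

import torch
import torch.nn as nn


class TinyAttnModule(nn.Module):
    # stand-in for AttentionModule: self-attention -> layer norm -> linear head
    def __init__(self, embed_dim=219, num_heads=1, out_dim=73):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.1)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.linear = nn.Linear(embed_dim, out_dim)

    def forward(self, X):
        out, _ = self.attn(query=X, key=X, value=X)
        return self.linear(self.layer_norm(out))

    def save_pretrained(self, path):
        torch.save(self, f"{path}.pt")  # whole module, not just a state_dict

    @staticmethod
    def from_pretrained(path):
        return torch.load(f"{path}.pt", weights_only=False)


module = TinyAttnModule().eval()
module.save_pretrained("/tmp/attn_aggregator_checkpoint")
restored = TinyAttnModule.from_pretrained("/tmp/attn_aggregator_checkpoint").eval()
x = torch.rand(8, 219)
assert torch.allclose(module(x), restored(x))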