Implemented multimodal pipeline; added gFunDataset interface; fixed imports
This commit is contained in:
parent
7041f7b651
commit
0c9454cdd4
@@ -180,3 +180,4 @@ amazon_cateogories.bu.txt
models/*
scripts/
logger/*
explore_data.ipynb
@@ -0,0 +1,223 @@
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from dataManager.glamiDataset import get_dataframe
from dataManager.multilingualDatset import MultilingualDataset


class gFunDataset:
    def __init__(
        self,
        dataset_dir,
        is_textual,
        is_visual,
        is_multilabel,
        labels=None,
        nrows=None,
        data_langs=None,
    ):
        self.dataset_dir = dataset_dir
        self.data_langs = data_langs
        self.is_textual = is_textual
        self.is_visual = is_visual
        self.is_multilabel = is_multilabel
        self.labels = labels
        self.nrows = nrows
        self.dataset = {}
        self.load_dataset()

    def get_label_binarizer(self, labels):
        if self.dataset_name in ["rcv1-2", "jrc"]:
            mlb = "Labels are already binarized for rcv1-2 dataset"
        elif self.is_multilabel:
            mlb = MultiLabelBinarizer()
            mlb.fit([labels])
        else:
            mlb = LabelBinarizer()
            mlb.fit(labels)
        return mlb

    def load_dataset(self):
        if "glami" in self.dataset_dir.lower():
            print(f"- Loading GLAMI dataset from {self.dataset_dir}")
            self.dataset_name = "glami"
            self.dataset, self.labels, self.data_langs = self._load_glami(
                self.dataset_dir, self.nrows
            )
            self.mlb = self.get_label_binarizer(self.labels)

        elif "rcv" in self.dataset_dir.lower():
            print(f"- Loading RCV1-2 dataset from {self.dataset_dir}")
            self.dataset_name = "rcv1-2"
            self.dataset, self.labels, self.data_langs = self._load_multilingual(
                self.dataset_name, self.dataset_dir, self.nrows
            )
            self.mlb = self.get_label_binarizer(self.labels)

        elif "jrc" in self.dataset_dir.lower():
            print(f"- Loading JRC dataset from {self.dataset_dir}")
            self.dataset_name = "jrc"
            self.dataset, self.labels, self.data_langs = self._load_multilingual(
                self.dataset_name, self.dataset_dir, self.nrows
            )
            self.mlb = self.get_label_binarizer(self.labels)

        self.show_dimension()

        return

    def show_dimension(self):
        print(f"\n[Dataset: {self.dataset_name.upper()}]")
        for lang, data in self.dataset.items():
            print(
                f"-- Lang: {lang} - train docs: {len(data['train']['text'])} - test docs: {len(data['test']['text'])}"
            )
        if self.dataset_name in ["rcv1-2", "jrc"]:
            print(f"-- Labels: {self.labels}")
        else:
            print(f"-- Labels: {len(self.labels)}")

    def _load_multilingual(self, dataset_name, dataset_dir, nrows):
        old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir)
        if nrows is not None:
            old_dataset.reduce_data(langs=["en", "it", "fr"], maxn=nrows)
        labels = old_dataset.num_labels()
        data_langs = old_dataset.langs()

        def _format_multilingual(data):
            text = data[0]
            image = None
            labels = data[1]
            return {"text": text, "image": image, "label": labels}

        dataset = {
            k: {"train": _format_multilingual(v[0]), "test": _format_multilingual(v[1])}
            for k, v in old_dataset.multiling_dataset.items()
        }
        return dataset, labels, data_langs

    def _load_glami(self, dataset_dir, nrows):
        def _balanced_sample(data, n, remainder=0):
            import pandas as pd

            langs = sorted(data.geo.unique().tolist())
            dict_n = {lang: n for lang in langs}
            dict_n[langs[0]] += remainder

            sampled = []
            for lang in langs:
                sampled.append(data[data.geo == lang].sample(n=dict_n[lang]))

            return pd.concat(sampled, axis=0)

        # TODO: set this sampling as deterministic/depending on the seed
        lang_nrows = (
            nrows // 13 if self.data_langs is None else nrows // len(self.data_langs)
        )  # GLAMI 1-M has 13 languages
        remainder = (
            nrows % 13 if self.data_langs is None else nrows % len(self.data_langs)
        )

        train_split = get_dataframe("train", dataset_dir=dataset_dir)
        train_split = _balanced_sample(train_split, lang_nrows, remainder=remainder)

        if self.data_langs is None:
            data_langs = sorted(train_split.geo.unique().tolist())
        # TODO: if data langs is NOT none then we have a problem where we filter df by langs
        if self.labels is None:
            labels = train_split.category_name.unique().tolist()

        # TODO: atm test data should contain same languages as train data
        test_split = get_dataframe("test", dataset_dir=dataset_dir)
        # TODO: atm we're using 1:1 train-test
        test_split = _balanced_sample(test_split, lang_nrows, remainder=remainder)

        gb_train = train_split.groupby("geo")
        gb_test = test_split.groupby("geo")

        def _format_glami(data_df):
            text = (data_df.name + " " + data_df.description).tolist()
            image = data_df.image_file.tolist()
            labels = data_df.category_name.tolist()
            return {"text": text, "image": image, "label": labels}

        dataset = {
            lang: {
                "train": _format_glami(data_tr),
                "test": _format_glami(gb_test.get_group(lang)),
            }
            for lang, data_tr in gb_train
            if lang in data_langs
        }

        return dataset, labels, data_langs

    def binarize_labels(self, labels):
        if self.dataset_name in ["rcv1-2", "jrc"]:
            # labels are already binarized for rcv1-2 dataset
            return labels
        if hasattr(self, "mlb"):
            return self.mlb.transform(labels)
        else:
            raise AttributeError("Label binarizer not found")

    def training(self):
        lXtr = {}
        lYtr = {}
        for lang in self.data_langs:
            text = self.dataset[lang]["train"]["text"] if self.is_textual else None
            img = self.dataset[lang]["train"]["image"] if self.is_visual else None
            labels = self.dataset[lang]["train"]["label"]

            lXtr[lang] = {"text": text, "image": img}
            lYtr[lang] = self.binarize_labels(labels)

        return lXtr, lYtr

    def test(self):
        lXte = {}
        lYte = {}
        for lang in self.data_langs:
            text = self.dataset[lang]["test"]["text"] if self.is_textual else None
            img = self.dataset[lang]["test"]["image"] if self.is_visual else None
            labels = self.dataset[lang]["test"]["label"]

            lXte[lang] = {"text": text, "image": img}
            lYte[lang] = self.binarize_labels(labels)

        return lXte, lYte

    def langs(self):
        return self.data_langs

    def num_labels(self):
        if self.dataset_name not in ["rcv1-2", "jrc"]:
            return len(self.labels)
        else:
            return self.labels


if __name__ == "__main__":
    import os

    GLAMI_DATAPATH = os.path.expanduser("~/datasets/GLAMI-1M-dataset")
    RCV_DATAPATH = os.path.expanduser(
        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
    )
    JRC_DATAPATH = os.path.expanduser(
        "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
    )

    print("Hello gFunDataset")
    dataset = gFunDataset(
        # dataset_dir=GLAMI_DATAPATH,
        # dataset_dir=RCV_DATAPATH,
        dataset_dir=JRC_DATAPATH,
        data_langs=None,
        is_textual=True,
        is_visual=True,
        is_multilabel=False,
        labels=None,
        nrows=13,
    )
    lXtr, lYtr = dataset.training()
    lXte, lYte = dataset.test()
    exit(0)
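A short usage sketch (illustrative, not part of the diff) of the per-language, per-modality layout that gFunDataset.training() and test() return, assuming a constructed dataset as in the __main__ block above:

# Hedged illustration: each language maps to a dict of modalities; a modality
# that was not requested (is_textual / is_visual set to False) is simply None.
lXtr, lYtr = dataset.training()
for lang, modalities in lXtr.items():
    texts = modalities["text"]    # list of strings, or None
    images = modalities["image"]  # list of image file paths, or None
    y = lYtr[lang]                # labels as returned by binarize_labels()
    print(lang, 0 if texts is None else len(texts), getattr(y, "shape", len(y)))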
@@ -100,6 +100,26 @@ class GlamiDataset:
        self.nrows = nrows
        self.multilingual_dataset = {}

        """
        self.multilingual_multimodal_dataset = {
            lang: {
                "text": {txt_data},
                "image": {img_data},
            }
        }
        TODO: if we decide to do this, we need to change both the
        training (e.g. vectorizer should call "text") and also the
        multilingual unimodal dataset (to include the field "text" only).
        BUT this will be a pain when we split/shuffle the datasets.
        I think it is better to have smt like this:

        self.ml_mm_dataset = {
            "lang": (txt_data, img_data)
        }

        but then also the unimodal dataset should have a "lang": (txt_data, _) value
        """

    def num_labels(self):
        return len(self.labels)

@@ -143,7 +163,7 @@ class GlamiDataset:
    def training(self):
        # TODO: tolist() or ???
        lXtr = {
            lang: (df.name + " " + df.description).tolist()
            lang: ((df.name + " " + df.description).tolist(), df.image_file.tolist())
            for lang, (df, _) in self.multilingual_dataset.items()
        }
        lYtr = {

@@ -154,7 +174,7 @@ class GlamiDataset:

    def test(self):
        lXte = {
            lang: (df.name + " " + df.description).tolist()
            lang: ((df.name + " " + df.description).tolist(), df.image_file.tolist())
            for lang, (_, df) in self.multilingual_dataset.items()
        }
        lYte = {
@@ -86,7 +86,7 @@ class MultiNewsDataset:
        lXtr = {}
        lYtr = {}
        for lang, data in self.lang_multiModalDataset.items():
            _data = [clean_text for _, clean_text, _, _ in data.data]
            _data = [(clean_text, img) for _, clean_text, _, img in data.data]
            lXtr[lang] = _data
        lYtr = {
            lang: self.label_binarizer.transform(data.labels)
@@ -64,7 +64,7 @@ class MultilingualDataset:
    def __init__(self, dataset_name):
        self.dataset_name = dataset_name
        self.multiling_dataset = {}
        print(f"[Init Multilingual Dataset: {self.dataset_name}]")
        # print(f"[Init Multilingual Dataset: {self.dataset_name}]")

    def add(self, lang, Xtr, Ytr, Xte, Yte, tr_ids=None, te_ids=None):
        self.multiling_dataset[lang] = ((Xtr, Ytr, tr_ids), (Xte, Yte, te_ids))
@@ -7,9 +7,8 @@ def evaluation_metrics(y, y_):
    if len(y.shape) == len(y_.shape) == 1 and len(np.unique(y)) > 2:  # single-label
        raise NotImplementedError()  # return f1_score(y,y_,average='macro'), f1_score(y,y_,average='micro')
    else:  # the metrics I implemented assume multiclass multilabel classification as binary classifiers
        # return macroF1(y, y_), microF1(y, y_), macroK(y, y_), microK(y, y_), macroP(y, y_), microP(y, y_), macroR(y, y_), microR(y, y_)
        # return macroF1(y, y_), microF1(y, y_), macroAcc(y, y_), microAcc(y, y_), macroP(y, y_), microP(y, y_), macroR(y, y_), microR(y, y_), macroAcc(y, y_)
        return macroF1(y, y_), microF1(y, y_), macroK(y, y_), microK(y, y_)
        # return macroF1(y, y_), microF1(y, y_), macroK(y, y_), macroAcc(y, y_)


def evaluate(ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1):
@@ -235,3 +235,7 @@ def macroK(true_labels, predicted_labels):
# true_labels and predicted_labels are two matrices in sklearn.preprocessing.MultiLabelBinarizer format
def microK(true_labels, predicted_labels):
    return micro_average(true_labels, predicted_labels, K)


def macroAcc(true_labels, predicted_labels):
    return macro_average(true_labels, predicted_labels, accuracy)
@@ -1,17 +1,18 @@
import os
import sys

sys.path.append(os.path.join(os.getcwd(), "gfun"))
# sys.path.append(os.path.join(os.getcwd(), "gfun"))

import pickle

import numpy as np
from vgfs.commons import TfidfVectorizerMultilingual, AttentionAggregator
from vgfs.learners.svms import MetaClassifier, get_learner
from vgfs.multilingualGen import MultilingualGen
from gfun.vgfs.commons import TfidfVectorizerMultilingual, AttentionAggregator
from gfun.vgfs.learners.svms import MetaClassifier, get_learner
from gfun.vgfs.multilingualGen import MultilingualGen
from gfun.vgfs.textualTransformerGen import TextualTransformerGen
from vgfs.vanillaFun import VanillaFunGen
from vgfs.wceGen import WceGen
from gfun.vgfs.visualTransformerGen import VisualTransformerGen
from gfun.vgfs.vanillaFun import VanillaFunGen
from gfun.vgfs.wceGen import WceGen


class GeneralizedFunnelling:

@@ -20,7 +21,8 @@ class GeneralizedFunnelling:
        posterior,
        wce,
        multilingual,
        transformer,
        textual_transformer,
        visual_transformer,
        langs,
        num_labels,
        embed_dir,

@@ -31,7 +33,8 @@ class GeneralizedFunnelling:
        epochs,
        patience,
        evaluate_step,
        transformer_name,
        textual_transformer_name,
        visual_transformer_name,
        optimc,
        device,
        load_trained,

@@ -44,15 +47,16 @@ class GeneralizedFunnelling:
        self.posteriors_vgf = posterior
        self.wce_vgf = wce
        self.multilingual_vgf = multilingual
        self.trasformer_vgf = transformer
        self.trasformer_vgf = textual_transformer
        self.visual_transformer_vgf = visual_transformer
        self.probabilistic = probabilistic
        self.num_labels = num_labels
        # ------------------------
        self.langs = langs
        self.embed_dir = embed_dir
        self.cached = True
        # Transformer VGF params ----------
        self.transformer_name = transformer_name
        # Textual Transformer VGF params ----------
        self.textaul_transformer_name = textual_transformer_name
        self.epochs = epochs
        self.lr_transformer = lr
        self.batch_size_transformer = batch_size

@@ -61,6 +65,8 @@ class GeneralizedFunnelling:
        self.patience = patience
        self.evaluate_step = evaluate_step
        self.device = device
        # Visual Transformer VGF params ----------
        self.visual_transformer_name = visual_transformer_name
        # Metaclassifier params ------------
        self.optimc = optimc
        # -------------------

@@ -78,7 +84,7 @@ class GeneralizedFunnelling:
        self._init()

    def _init(self):
        print("[Init GeneralizedFunnelling]")
        print("\n[Init GeneralizedFunnelling]")
        assert not (
            self.aggfunc == "mean" and self.probabilistic is False
        ), "When using averaging aggreagation function probabilistic must be True"

@@ -139,20 +145,35 @@ class GeneralizedFunnelling:
        if self.trasformer_vgf:
            transformer_vgf = TextualTransformerGen(
                dataset_name=self.dataset_name,
                model_name=self.transformer_name,
                model_name=self.textaul_transformer_name,
                lr=self.lr_transformer,
                epochs=self.epochs,
                batch_size=self.batch_size_transformer,
                max_length=self.max_length,
                device=self.device,
                print_steps=50,
                probabilistic=self.probabilistic,
                evaluate_step=self.evaluate_step,
                verbose=True,
                patience=self.patience,
                device=self.device,
            )
            self.first_tier_learners.append(transformer_vgf)

        if self.visual_transformer_vgf:
            visual_trasformer_vgf = VisualTransformerGen(
                dataset_name=self.dataset_name,
                model_name="vit",
                lr=1e-5, # self.lr_visual_transformer,
                epochs=self.epochs,
                batch_size=32, # self.batch_size_visual_transformer,
                # batch_size_eval=128,
                probabilistic=self.probabilistic,
                evaluate_step=self.evaluate_step,
                patience=self.patience,
                device=self.device,
            )
            self.first_tier_learners.append(visual_trasformer_vgf)

        if "attn" in self.aggfunc:
            attn_stacking = self.aggfunc.split("_")[1]
            self.attn_aggregator = AttentionAggregator(

@@ -189,15 +210,18 @@ class GeneralizedFunnelling:
            vgf.vectorizer = self.vectorizer

    def fit(self, lX, lY):
        print("[Fitting GeneralizedFunnelling]")
        print("\n[Fitting GeneralizedFunnelling]")
        if self.load_trained is not None:
            print(
                "- loaded first tier learners!"
                if self.load_meta is False
                else "- loaded trained model!"
            )
            """
            if we are only loading the first tier, we need to
            transform the training data to train the meta-classifier
            """
            if self.load_first_tier is True and self.load_meta is False:
                # TODO: clean up this code here
                projections = []
                for vgf in self.first_tier_learners:
                    l_posteriors = vgf.transform(lX)

@@ -403,7 +427,7 @@ def get_unique_id(dataset_name, posterior, multilingual, wce, transformer, aggfu
    from datetime import datetime

    now = datetime.now().strftime("%y%m%d")
    model_id = dataset_name
    model_id = f"{dataset_name}_"
    model_id += "p" if posterior else ""
    model_id += "m" if multilingual else ""
    model_id += "w" if wce else ""
@@ -4,17 +4,17 @@ from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize
from torch.optim import AdamW
from transformers.modeling_outputs import SequenceClassifierOutput
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from transformers.modeling_outputs import ModelOutput

from evaluation.evaluate import evaluate, log_eval

PRINT_ON_EPOCH = 10
PRINT_ON_EPOCH = 1


def _normalize(lX, l2=True):

@@ -78,12 +78,12 @@ class TfidfVectorizerMultilingual:
    def fit(self, lX, ly=None):
        self.langs = sorted(lX.keys())
        self.vectorizer = {
            l: TfidfVectorizer(**self.kwargs).fit(lX[l]) for l in self.langs
            l: TfidfVectorizer(**self.kwargs).fit(lX[l]["text"]) for l in self.langs
        }
        return self

    def transform(self, lX):
        return {l: self.vectorizer[l].transform(lX[l]) for l in self.langs}
        return {l: self.vectorizer[l].transform(lX[l]["text"]) for l in self.langs}

    def fit_transform(self, lX, ly=None):
        return self.fit(lX, ly).transform(lX)

@@ -123,6 +123,7 @@ class Trainer:
        self.print_steps = print_steps
        self.experiment_name = experiment_name
        self.patience = patience
        self.print_eval = evaluate_step
        self.earlystopping = EarlyStopping(
            patience=patience,
            checkpoint_path=checkpoint_path,

@@ -144,12 +145,15 @@ class Trainer:
            - train batch size: {train_dataloader.batch_size}
            - eval batch size: {eval_dataloader.batch_size}
            - max len: {train_dataloader.dataset.X.shape[-1]}
            - patience: {self.earlystopping.patience}\n"""
            - patience: {self.earlystopping.patience}
            - evaluate every: {self.evaluate_steps}
            - print eval every: {self.print_eval}
            - print train steps: {self.print_steps}\n"""
        )
        for epoch in range(epochs):
            self.train_epoch(train_dataloader, epoch)
            if (epoch + 1) % self.evaluate_steps == 0:
                print_eval = (epoch + 1) % 25 == 0
                print_eval = (epoch + 1) % self.print_eval == 0
                metric_watcher = self.evaluate(eval_dataloader, print_eval=print_eval)
                stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
                if stop:

@@ -160,7 +164,7 @@ class Trainer:
                        self.device
                    )
                    break
        print(f"\n- last swipe on eval set")
        print(f"- last swipe on eval set")
        self.train_epoch(eval_dataloader, epoch=0)
        self.earlystopping.save_model(self.model)
        return self.model

@@ -170,7 +174,7 @@ class Trainer:
        for b_idx, (x, y, lang) in enumerate(dataloader):
            self.optimizer.zero_grad()
            y_hat = self.model(x.to(self.device))
            if isinstance(y_hat, SequenceClassifierOutput):
            if isinstance(y_hat, ModelOutput):
                loss = self.loss_fn(y_hat.logits, y.to(self.device))
            else:
                loss = self.loss_fn(y_hat, y.to(self.device))

@@ -189,7 +193,7 @@ class Trainer:

        for b_idx, (x, y, lang) in enumerate(dataloader):
            y_hat = self.model(x.to(self.device))
            if isinstance(y_hat, SequenceClassifierOutput):
            if isinstance(y_hat, ModelOutput):
                loss = self.loss_fn(y_hat.logits, y.to(self.device))
                predictions = predict(y_hat.logits, classification_type="multilabel")
            else:

@@ -272,6 +276,10 @@ class AttentionModule(nn.Module):
        self.linear = nn.Linear(embed_dim, out_dim)
        self.sigmoid = nn.Sigmoid()

    def init_weights(self, mode="mean"):
        # TODO: add init function of the attention module: either all weights are positive or set to 1/num_classes
        raise NotImplementedError

    def __call__(self, X):
        out, attn_weights = self.attn(query=X, key=X, value=X)
        # out = self.layer_norm(out)
@@ -4,9 +4,9 @@ import torch
import numpy as np
from torchtext.vocab import Vectors
from joblib import Parallel, delayed
from vgfs.viewGen import ViewGen
from vgfs.commons import _normalize, XdotM
from vgfs.learners.svms import FeatureSet2Posteriors
from gfun.vgfs.viewGen import ViewGen
from gfun.vgfs.commons import _normalize, XdotM
from gfun.vgfs.learners.svms import FeatureSet2Posteriors


class MultilingualGen(ViewGen):
@@ -7,14 +7,14 @@ from collections import defaultdict
import numpy as np
import torch
import transformers
from sklearn.model_selection import train_test_split
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset

# from sklearn.model_selection import train_test_split
# from torch.optim import AdamW
from torch.utils.data import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from vgfs.learners.svms import FeatureSet2Posteriors
from vgfs.viewGen import ViewGen
from vgfs.transformerGen import TransformerGen
from vgfs.commons import Trainer, predict
from gfun.vgfs.commons import Trainer
from gfun.vgfs.transformerGen import TransformerGen
from gfun.vgfs.viewGen import ViewGen

transformers.logging.set_verbosity_error()

@@ -104,7 +104,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
        )

        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
            lX, lY, split=0.2, seed=42
            lX, lY, split=0.2, seed=42, modality="text"
        )

        tra_dataloader = self.build_dataloader(

@@ -156,6 +156,8 @@ class TextualTransformerGen(ViewGen, TransformerGen):
        return self

    def transform(self, lX):
        # forcing to only text modality
        lX = {lang: data["text"] for lang, data in lX.items()}
        _embeds = []
        l_embeds = defaultdict(list)

@@ -196,8 +198,8 @@ class TextualTransformerGen(ViewGen, TransformerGen):
        from os import makedirs
        from os.path import join

        vgf_name = "transformerGen"
        _basedir = join("models", "vgfs", "transformer")
        vgf_name = "textualTransformerGen"
        _basedir = join("models", "vgfs", "textual_transformer")
        makedirs(_basedir, exist_ok=True)
        _path = join(_basedir, f"{vgf_name}_{model_id}.pkl")
        with open(_path, "wb") as f:
@@ -1,6 +1,6 @@
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from vgfs.learners.svms import FeatureSet2Posteriors
from gfun.vgfs.learners.svms import FeatureSet2Posteriors


class TransformerGen:

@@ -67,16 +67,26 @@ class TransformerGen:
        split="train",
        shuffle=False,
    ):
        l_tokenized = {lang: processor_fn(data) for lang, data in lX.items()}
        self.datasets[split] = torchDataset(l_tokenized, lY, split=split)
        return DataLoader(self.datasets[split], batch_size=batch_size, shuffle=shuffle)
        l_processed = {lang: processor_fn(lX[lang]) for lang in lX.keys()}
        self.datasets[split] = torchDataset(l_processed, lY, split=split)
        return DataLoader(
            self.datasets[split],
            batch_size=batch_size,
            shuffle=shuffle,
            # collate_fn=processor_fn,
        )

    def get_train_val_data(self, lX, lY, split=0.2, seed=42):
    def get_train_val_data(self, lX, lY, split=0.2, seed=42, modality="text"):
        assert modality in ["text", "image"], "modality must be either text or image"
        tr_lX, tr_lY, val_lX, val_lY = {}, {}, {}, {}

        for lang in lX.keys():
            tr_X, val_X, tr_Y, val_Y = train_test_split(
                lX[lang], lY[lang], test_size=split, random_state=seed, shuffle=False
                lX[lang][modality],
                lY[lang],
                test_size=split,
                random_state=seed,
                shuffle=False,
            )
            tr_lX[lang] = tr_X
            tr_lY[lang] = tr_Y
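A minimal sketch (illustrative, not part of the commit) of how the new modality argument selects one field of the multimodal input before the per-language train/validation split; the dict shapes follow what gFunDataset.training() returns:

# lX[lang] now carries both modalities; each VGF picks the one it trains on.
lX = {"en": {"text": ["doc a", "doc b"], "image": ["img_a.jpg", "img_b.jpg"]}}
lY = {"en": [[0, 1], [1, 0]]}
# TextualTransformerGen calls: get_train_val_data(lX, lY, modality="text")
# VisualTransformerGen calls:  get_train_val_data(lX, lY, modality="image")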
@@ -1,6 +1,6 @@
from vgfs.viewGen import ViewGen
from vgfs.learners.svms import NaivePolylingualClassifier
from vgfs.commons import _normalize
from gfun.vgfs.viewGen import ViewGen
from gfun.vgfs.learners.svms import NaivePolylingualClassifier
from gfun.vgfs.commons import _normalize


class VanillaFunGen(ViewGen):
@@ -1,16 +1,15 @@
import sys, os

sys.path.append(os.getcwd())
from collections import defaultdict

import numpy as np
import torch
import transformers
from gfun.vgfs.viewGen import ViewGen
from transformers import AutoImageProcessor
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
from gfun.vgfs.commons import Trainer, predict
from PIL import Image
from torch.utils.data import Dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification

from gfun.vgfs.commons import Trainer
from gfun.vgfs.transformerGen import TransformerGen
from transformers import AutoModelForImageClassification
from gfun.vgfs.viewGen import ViewGen

transformers.logging.set_verbosity_error()

@@ -26,6 +25,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
        batch_size_eval=128,
        evaluate_step=10,
        device="cpu",
        probabilistic=False,
        patience=5,
    ):
        super().__init__(

@@ -38,6 +38,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
            device=device,
            evaluate_step=evaluate_step,
            patience=patience,
            probabilistic=probabilistic,
        )
        self.fitted = False
        print(

@@ -52,44 +53,28 @@ class VisualTransformerGen(ViewGen, TransformerGen):

    def init_model(self, model_name, num_labels):
        model = AutoModelForImageClassification.from_pretrained(
            model_name, num_labels=num_labels
            model_name, num_labels=num_labels, output_hidden_states=True
        )
        image_processor = AutoImageProcessor.from_pretrained(model_name)
        transforms = self.init_preprocessor(image_processor)
        return model, image_processor, transforms

    def init_preprocessor(self, image_processor):
        normalize = Normalize(
            mean=image_processor.image_mean, std=image_processor.image_std
        )
        size = (
            image_processor.size["shortest_edge"]
            if "shortest_edge" in image_processor.size
            else (image_processor.size["height"], image_processor.size["width"])
        )
        # these are the transformations that we are applying to the images
        transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])
        return transforms

    def preprocess(self, images, transforms):
        processed = transforms(img.convert("RGB") for img in images)
        return processed
        return model, image_processor

    def process_all(self, X):
        # TODO: every element in X is a tuple (doc_id, clean_text, text, Pil.Image), so we're taking just the last element for processing
        processed = torch.stack([self.transforms(img[-1]) for img in X])
        return processed
        # TODO: should be moved as a collate_fn to avoid this overhead
        processed = self.image_preprocessor(
            [Image.open(img).convert("RGB") for img in X], return_tensors="pt"
        )
        return processed["pixel_values"]

    def fit(self, lX, lY):
        print("- fitting Visual Transformer View Generating Function")
        _l = list(lX.keys())[0]
        self.num_labels = lY[_l].shape[-1]
        self.model, self.image_preprocessor, self.transforms = self.init_model(
        self.model, self.image_preprocessor = self.init_model(
            self._validate_model_name(self.model_name), num_labels=self.num_labels
        )

        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
            lX, lY, split=0.2, seed=42
            lX, lY, split=0.2, seed=42, modality="image"
        )

        tra_dataloader = self.build_dataloader(

@@ -135,17 +120,64 @@ class VisualTransformerGen(ViewGen, TransformerGen):
        if self.probabilistic:
            self.feature2posterior_projector.fit(self.transform(lX), lY)

        self.fitted = True

        return self

    def transform(self, lX):
        raise NotImplementedError
        # forcing to only image modality
        lX = {lang: data["image"] for lang, data in lX.items()}
        _embeds = []
        l_embeds = defaultdict(list)

        dataloader = self.build_dataloader(
            lX,
            lY=None,
            processor_fn=self.process_all,
            torchDataset=MultimodalDatasetTorch,
            batch_size=self.batch_size_eval,
            split="whole",
            shuffle=False,
        )

        self.model.eval()
        with torch.no_grad():
            for input_ids, lang in dataloader:
                input_ids = input_ids.to(self.device)
                out = self.model(input_ids).hidden_states[-1]
                batch_embeddings = out[:, 0, :].cpu().numpy()
                _embeds.append((batch_embeddings, lang))

        for embed, lang in _embeds:
            for sample_embed, sample_lang in zip(embed, lang):
                l_embeds[sample_lang].append(sample_embed)

        if self.probabilistic and self.fitted:
            l_embeds = self.feature2posterior_projector.transform(l_embeds)
        elif not self.probabilistic and self.fitted:
            l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}

        return l_embeds

    def fit_transform(self, lX, lY):
        raise NotImplementedError
        return self.fit(lX, lY).transform(lX)

    def save_vgf(self, model_id):
        raise NotImplementedError
        import pickle
        from os import makedirs
        from os.path import join

    def save_vgf(self, model_id):
        raise NotImplementedError
        vgf_name = "visualTransformerGen"
        _basedir = join("models", "vgfs", "visual_transformer")
        makedirs(_basedir, exist_ok=True)
        _path = join(_basedir, f"{vgf_name}_{model_id}.pkl")
        with open(_path, "wb") as f:
            pickle.dump(self, f)
        return self

    def __str__(self):
        str = f"[Visual Transformer VGF (v)]\n- model_name: {self.model_name}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
        return str


class MultimodalDatasetTorch(Dataset):

@@ -169,7 +201,6 @@ class MultimodalDatasetTorch(Dataset):
            ],
            [],
        )
        # print(f"- lX has shape: {self.X.shape}\n- lY has shape: {self.Y.shape}")

    def __len__(self):
        return len(self.X)

@@ -182,15 +213,28 @@ class MultimodalDatasetTorch(Dataset):

if __name__ == "__main__":
    from os.path import expanduser
    from dataManager.multiNewsDataset import MultiNewsDataset

    _dataset_path_hardcoded = "~/datasets/MultiNews/20110730/"
    from dataManager.gFunDataset import gFunDataset

    dataset = MultiNewsDataset(expanduser(_dataset_path_hardcoded), debug=True)
    lXtr, lYtr = dataset.training()
    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
    dataset = gFunDataset(
        dataset_dir=GLAMI_DATAPATH,
        is_textual=True,
        is_visual=True,
        is_multilabel=False,
        nrows=50,
    )

    vg = VisualTransformerGen(
        model_name="vit", device="cuda", epochs=1000, evaluate_step=10, patience=100
        dataset_name=dataset.dataset_name,
        model_name="vit",
        device="cuda",
        epochs=5,
        evaluate_step=10,
        patience=10,
        probabilistic=True,
    )
    lX, lY = dataset.training()
    vg.fit(lX, lY)
    out = vg.transform(lX)
    exit(0)
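For reference, a self-contained sketch (under assumptions, not part of the diff) of the pooling strategy the new transform() uses: the image processor batches PIL images into pixel tensors and the first (CLS) token of the last hidden state serves as the image embedding. The checkpoint name below is only an illustration; the commit resolves "vit" to its own choice via _validate_model_name.

from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch
from PIL import Image

# Assumed checkpoint for illustration purposes only.
name = "google/vit-base-patch16-224"
processor = AutoImageProcessor.from_pretrained(name)
model = AutoModelForImageClassification.from_pretrained(name, output_hidden_states=True)

# "example.jpg" is a placeholder input image.
batch = processor([Image.open("example.jpg").convert("RGB")], return_tensors="pt")
with torch.no_grad():
    out = model(batch["pixel_values"])
cls_embedding = out.hidden_states[-1][:, 0, :]  # (batch, hidden_dim) image embedding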
@@ -1,7 +1,7 @@
import numpy as np
from joblib import Parallel, delayed
from vgfs.commons import XdotM, _normalize
from vgfs.viewGen import ViewGen
from gfun.vgfs.commons import XdotM, _normalize
from gfun.vgfs.viewGen import ViewGen


class WceGen(ViewGen):
main.py
@@ -7,6 +7,7 @@ from dataManager.amazonDataset import AmazonDataset
from dataManager.multilingualDatset import MultilingualDataset
from dataManager.multiNewsDataset import MultiNewsDataset
from dataManager.glamiDataset import GlamiDataset
from dataManager.gFunDataset import gFunDataset
from evaluation.evaluate import evaluate, log_eval
from gfun.generalizedFunnelling import GeneralizedFunnelling

@@ -44,11 +45,15 @@ def get_dataset(datasetname):
    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")

    if datasetname == "multinews":
        # TODO: convert to gFunDataset
        raise NotImplementedError
        dataset = MultiNewsDataset(
            expanduser(MULTINEWS_DATAPATH),
            excluded_langs=["ar", "pe", "pl", "tr", "ua"],
        )
    elif datasetname == "amazon":
        # TODO: convert to gFunDataset
        raise NotImplementedError
        dataset = AmazonDataset(
            domains=args.domains,
            nrows=args.nrows,

@@ -56,12 +61,21 @@ def get_dataset(datasetname):
            max_labels=args.max_labels,
        )
    elif datasetname == "rcv1-2":
        dataset = MultilingualDataset(dataset_name="rcv1-2").load(RCV_DATAPATH)
        if args.nrows is not None:
            dataset.reduce_data(langs=["en", "it", "fr"], maxn=args.nrows)
        dataset = gFunDataset(
            dataset_dir=RCV_DATAPATH,
            is_textual=True,
            is_visual=False,
            is_multilabel=True,
            nrows=args.nrows,
        )
    elif datasetname == "glami":
        dataset = GlamiDataset(dataset_dir=GLAMI_DATAPATH, nrows=args.nrows)
        dataset.build_dataset()
        dataset = gFunDataset(
            dataset_dir=GLAMI_DATAPATH,
            is_textual=True,
            is_visual=True,
            is_multilabel=False,
            nrows=args.nrows,
        )
    else:
        raise NotImplementedError
    return dataset

@@ -73,6 +87,7 @@ def main(args):
        isinstance(dataset, MultilingualDataset)
        or isinstance(dataset, MultiNewsDataset)
        or isinstance(dataset, GlamiDataset)
        or isinstance(dataset, gFunDataset)
    ):
        lX, lY = dataset.training()
        lX_te, lY_te = dataset.test()

@@ -89,7 +104,8 @@ def main(args):
            args.wce,
            args.multilingual,
            args.multilingual,
            args.transformer,
            args.textual_transformer,
            args.visual_transformer,
        ]
    ), "At least one of VGF must be True"

@@ -106,8 +122,8 @@ def main(args):
        # WCE VGF params ----------------------
        wce=args.wce,
        # Transformer VGF params --------------
        transformer=args.transformer,
        transformer_name=args.transformer_name,
        textual_transformer=args.textual_transformer,
        textual_transformer_name=args.transformer_name,
        batch_size=args.batch_size,
        epochs=args.epochs,
        lr=args.lr,

@@ -115,6 +131,15 @@ def main(args):
        patience=args.patience,
        evaluate_step=args.evaluate_step,
        device="cuda",
        # Visual Transformer VGF params --------------
        visual_transformer=args.visual_transformer,
        visual_transformer_name=args.visual_transformer_name,
        # batch_size=args.batch_size,
        # epochs=args.epochs,
        # lr=args.lr,
        # patience=args.patience,
        # evaluate_step=args.evaluate_step,
        # device="cuda",
        # General params ----------------------
        probabilistic=args.features,
        aggfunc=args.aggfunc,

@@ -152,7 +177,7 @@ if __name__ == "__main__":
    parser.add_argument("--meta", action="store_true")
    parser.add_argument("--nosave", action="store_true")
    # Dataset parameters -------------------
    parser.add_argument("-d", "--dataset", type=str, default="multinews")
    parser.add_argument("-d", "--dataset", type=str, default="rcv1-2")
    parser.add_argument("--domains", type=str, default="all")
    parser.add_argument("--nrows", type=int, default=None)
    parser.add_argument("--min_count", type=int, default=10)

@@ -161,7 +186,8 @@ if __name__ == "__main__":
    parser.add_argument("-p", "--posteriors", action="store_true")
    parser.add_argument("-m", "--multilingual", action="store_true")
    parser.add_argument("-w", "--wce", action="store_true")
    parser.add_argument("-t", "--transformer", action="store_true")
    parser.add_argument("-t", "--textual_transformer", action="store_true")
    parser.add_argument("-v", "--visual_transformer", action="store_true")
    parser.add_argument("--n_jobs", type=int, default=-1)
    parser.add_argument("--optimc", action="store_true")
    parser.add_argument("--features", action="store_false")

@@ -169,11 +195,13 @@ if __name__ == "__main__":
    # transformer parameters ---------------
    parser.add_argument("--transformer_name", type=str, default="mbert")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=1000)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--patience", type=int, default=5)
    parser.add_argument("--evaluate_step", type=int, default=10)
    # Visual Transformer parameters --------------
    parser.add_argument("--visual_transformer_name", type=str, default="vit")

    args = parser.parse_args()