implemented multimodal pipeline; gFunDataset interface; fixed imports

This commit is contained in:
Andrea Pedrotti 2023-03-02 18:16:46 +01:00
parent 7041f7b651
commit 0c9454cdd4
16 changed files with 480 additions and 117 deletions
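The core addition is the gFunDataset loader below: its training() and test() methods return per-language dictionaries keyed by modality, which the textual and visual view-generating functions slice as needed. A minimal sketch of the resulting layout (field names taken from the code in this commit, sample values hypothetical):

import os
from dataManager.gFunDataset import gFunDataset

# hypothetical GLAMI run: both modalities, single-label, 100 rows
dataset = gFunDataset(
    dataset_dir=os.path.expanduser("~/datasets/GLAMI-1M-dataset"),
    is_textual=True,
    is_visual=True,
    is_multilabel=False,
    nrows=100,
)
lX, lY = dataset.training()
# lX == {lang: {"text": [<name + description>, ...], "image": [<image path>, ...]}, ...}
# lY == {lang: <binarized label matrix>, ...}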

1
.gitignore vendored
View File

@ -180,3 +180,4 @@ amazon_cateogories.bu.txt
models/*
scripts/
logger/*
explore_data.ipynb

223
dataManager/gFunDataset.py Normal file
View File

@ -0,0 +1,223 @@
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from dataManager.glamiDataset import get_dataframe
from dataManager.multilingualDatset import MultilingualDataset
class gFunDataset:
def __init__(
self,
dataset_dir,
is_textual,
is_visual,
is_multilabel,
labels=None,
nrows=None,
data_langs=None,
):
self.dataset_dir = dataset_dir
self.data_langs = data_langs
self.is_textual = is_textual
self.is_visual = is_visual
self.is_multilabel = is_multilabel
self.labels = labels
self.nrows = nrows
self.dataset = {}
self.load_dataset()
def get_label_binarizer(self, labels):
if self.dataset_name in ["rcv1-2", "jrc"]:
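# rcv1-2 and jrc ship with already-binarized labels, so no binarizer is fitted here; the string assigned below is only a placeholder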
mlb = "Labels are already binarized for rcv1-2 dataset"
elif self.is_multilabel:
mlb = MultiLabelBinarizer()
mlb.fit([labels])
else:
mlb = LabelBinarizer()
mlb.fit(labels)
return mlb
def load_dataset(self):
if "glami" in self.dataset_dir.lower():
print(f"- Loading GLAMI dataset from {self.dataset_dir}")
self.dataset_name = "glami"
self.dataset, self.labels, self.data_langs = self._load_glami(
self.dataset_dir, self.nrows
)
self.mlb = self.get_label_binarizer(self.labels)
elif "rcv" in self.dataset_dir.lower():
print(f"- Loading RCV1-2 dataset from {self.dataset_dir}")
self.dataset_name = "rcv1-2"
self.dataset, self.labels, self.data_langs = self._load_multilingual(
self.dataset_name, self.dataset_dir, self.nrows
)
self.mlb = self.get_label_binarizer(self.labels)
elif "jrc" in self.dataset_dir.lower():
print(f"- Loading JRC dataset from {self.dataset_dir}")
self.dataset_name = "jrc"
self.dataset, self.labels, self.data_langs = self._load_multilingual(
self.dataset_name, self.dataset_dir, self.nrows
)
self.mlb = self.get_label_binarizer(self.labels)
self.show_dimension()
return
def show_dimension(self):
print(f"\n[Dataset: {self.dataset_name.upper()}]")
for lang, data in self.dataset.items():
print(
f"-- Lang: {lang} - train docs: {len(data['train']['text'])} - test docs: {len(data['test']['text'])}"
)
if self.dataset_name in ["rcv1-2", "jrc"]:
print(f"-- Labels: {self.labels}")
else:
print(f"-- Labels: {len(self.labels)}")
def _load_multilingual(self, dataset_name, dataset_dir, nrows):
old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir)
if nrows is not None:
old_dataset.reduce_data(langs=["en", "it", "fr"], maxn=nrows)
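# NB: for rcv1-2 / jrc, "labels" holds the number of classes (the label vectors are already binarized), not the class names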
labels = old_dataset.num_labels()
data_langs = old_dataset.langs()
def _format_multilingual(data):
text = data[0]
image = None
labels = data[1]
return {"text": text, "image": image, "label": labels}
dataset = {
k: {"train": _format_multilingual(v[0]), "test": _format_multilingual(v[1])}
for k, v in old_dataset.multiling_dataset.items()
}
return dataset, labels, data_langs
def _load_glami(self, dataset_dir, nrows):
def _balanced_sample(data, n, remainder=0):
import pandas as pd
langs = sorted(data.geo.unique().tolist())
dict_n = {lang: n for lang in langs}
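# the first language (in alphabetical order) absorbs the remainder so the sampled total matches nrows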
dict_n[langs[0]] += remainder
sampled = []
for lang in langs:
sampled.append(data[data.geo == lang].sample(n=dict_n[lang]))
return pd.concat(sampled, axis=0)
# TODO: make this sampling deterministic / dependent on the seed
lang_nrows = (
nrows // 13 if self.data_langs is None else nrows // len(self.data_langs)
) # GLAMI-1M has 13 languages
remainder = (
nrows % 13 if self.data_langs is None else nrows % len(self.data_langs)
)
train_split = get_dataframe("train", dataset_dir=dataset_dir)
train_split = _balanced_sample(train_split, lang_nrows, remainder=remainder)
if self.data_langs is None:
data_langs = sorted(train_split.geo.unique().tolist())
else:
# TODO: when data_langs is given, the dataframes should also be filtered to those languages
data_langs = self.data_langs
if self.labels is None:
labels = train_split.category_name.unique().tolist()
else:
labels = self.labels
# TODO: atm test data should contain same languages as train data
test_split = get_dataframe("test", dataset_dir=dataset_dir)
# TODO: atm we're using 1:1 train-test
test_split = _balanced_sample(test_split, lang_nrows, remainder=remainder)
gb_train = train_split.groupby("geo")
gb_test = test_split.groupby("geo")
def _format_glami(data_df):
text = (data_df.name + " " + data_df.description).tolist()
image = data_df.image_file.tolist()
labels = data_df.category_name.tolist()
return {"text": text, "image": image, "label": labels}
dataset = {
lang: {
"train": _format_glami(data_tr),
"test": _format_glami(gb_test.get_group(lang)),
}
for lang, data_tr in gb_train
if lang in data_langs
}
return dataset, labels, data_langs
def binarize_labels(self, labels):
if self.dataset_name in ["rcv1-2", "jrc"]:
# labels are already binarized for the rcv1-2 and jrc datasets
return labels
if hasattr(self, "mlb"):
return self.mlb.transform(labels)
else:
raise AttributeError("Label binarizer not found")
def training(self):
lXtr = {}
lYtr = {}
for lang in self.data_langs:
text = self.dataset[lang]["train"]["text"] if self.is_textual else None
img = self.dataset[lang]["train"]["image"] if self.is_visual else None
labels = self.dataset[lang]["train"]["label"]
lXtr[lang] = {"text": text, "image": img}
lYtr[lang] = self.binarize_labels(labels)
return lXtr, lYtr
def test(self):
lXte = {}
lYte = {}
for lang in self.data_langs:
text = self.dataset[lang]["test"]["text"] if self.is_textual else None
img = self.dataset[lang]["test"]["image"] if self.is_visual else None
labels = self.dataset[lang]["test"]["label"]
lXte[lang] = {"text": text, "image": img}
lYte[lang] = self.binarize_labels(labels)
return lXte, lYte
def langs(self):
return self.data_langs
def num_labels(self):
if self.dataset_name not in ["rcv1-2", "jrc"]:
return len(self.labels)
else:
return self.labels
if __name__ == "__main__":
import os
GLAMI_DATAPATH = os.path.expanduser("~/datasets/GLAMI-1M-dataset")
RCV_DATAPATH = os.path.expanduser(
"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
)
JRC_DATAPATH = os.path.expanduser(
"~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
)
print("Hello gFunDataset")
dataset = gFunDataset(
# dataset_dir=GLAMI_DATAPATH,
# dataset_dir=RCV_DATAPATH,
dataset_dir=JRC_DATAPATH,
data_langs=None,
is_textual=True,
is_visual=True,
is_multilabel=False,
labels=None,
nrows=13,
)
lXtr, lYtr = dataset.training()
lXte, lYte = dataset.test()
exit(0)

View File

@ -100,6 +100,26 @@ class GlamiDataset:
self.nrows = nrows
self.multilingual_dataset = {}
"""
self.multilingual_multimodal_dataset = {
lang: {
"text": {txt_data},
"image": {img_data},
}
}
TODO: if we decide to do this, we need to change both the
training (e.g. vectorizer should call "text") and also the
multilingual unimodal dataset (to include the field "text" only).
BUT this will be a pain when we split/shuffle the datasets.
I think it is better to have something like this:
self.ml_mm_dataset = {
"lang": (txt_data, img_data)
}
but then also the unimodal dataset should have a "lang": (txt_data, _) value
"""
def num_labels(self):
return len(self.labels)
@ -143,7 +163,7 @@ class GlamiDataset:
def training(self):
# TODO: tolist() or ???
lXtr = {
lang: (df.name + " " + df.description).tolist()
lang: ((df.name + " " + df.description).tolist(), df.image_file.tolist())
for lang, (df, _) in self.multilingual_dataset.items()
}
lYtr = {
@ -154,7 +174,7 @@ class GlamiDataset:
def test(self):
lXte = {
lang: (df.name + " " + df.description).tolist()
lang: ((df.name + " " + df.description).tolist(), df.image_file.tolist())
for lang, (_, df) in self.multilingual_dataset.items()
}
lYte = {

View File

@ -86,7 +86,7 @@ class MultiNewsDataset:
lXtr = {}
lYtr = {}
for lang, data in self.lang_multiModalDataset.items():
_data = [clean_text for _, clean_text, _, _ in data.data]
_data = [(clean_text, img) for _, clean_text, _, img in data.data]
lXtr[lang] = _data
lYtr = {
lang: self.label_binarizer.transform(data.labels)

View File

@ -64,7 +64,7 @@ class MultilingualDataset:
def __init__(self, dataset_name):
self.dataset_name = dataset_name
self.multiling_dataset = {}
print(f"[Init Multilingual Dataset: {self.dataset_name}]")
# print(f"[Init Multilingual Dataset: {self.dataset_name}]")
def add(self, lang, Xtr, Ytr, Xte, Yte, tr_ids=None, te_ids=None):
self.multiling_dataset[lang] = ((Xtr, Ytr, tr_ids), (Xte, Yte, te_ids))

View File

@ -7,9 +7,8 @@ def evaluation_metrics(y, y_):
if len(y.shape) == len(y_.shape) == 1 and len(np.unique(y)) > 2: # single-label
raise NotImplementedError() # return f1_score(y,y_,average='macro'), f1_score(y,y_,average='micro')
else: # the implemented metrics assume multiclass/multilabel classification, evaluated as per-class binary classifiers
# return macroF1(y, y_), microF1(y, y_), macroK(y, y_), microK(y, y_), macroP(y, y_), microP(y, y_), macroR(y, y_), microR(y, y_)
# return macroF1(y, y_), microF1(y, y_), macroAcc(y, y_), microAcc(y, y_), macroP(y, y_), microP(y, y_), macroR(y, y_), microR(y, y_), macroAcc(y, y_)
return macroF1(y, y_), microF1(y, y_), macroK(y, y_), microK(y, y_)
# return macroF1(y, y_), microF1(y, y_), macroK(y, y_), macroAcc(y, y_)
def evaluate(ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1):

View File

@ -235,3 +235,7 @@ def macroK(true_labels, predicted_labels):
# true_labels and predicted_labels are two matrices in sklearn.preprocessing.MultiLabelBinarizer format
def microK(true_labels, predicted_labels):
return micro_average(true_labels, predicted_labels, K)
def macroAcc(true_labels, predicted_labels):
return macro_average(true_labels, predicted_labels, accuracy)

View File

@ -1,17 +1,18 @@
import os
import sys
sys.path.append(os.path.join(os.getcwd(), "gfun"))
# sys.path.append(os.path.join(os.getcwd(), "gfun"))
import pickle
import numpy as np
from vgfs.commons import TfidfVectorizerMultilingual, AttentionAggregator
from vgfs.learners.svms import MetaClassifier, get_learner
from vgfs.multilingualGen import MultilingualGen
from gfun.vgfs.commons import TfidfVectorizerMultilingual, AttentionAggregator
from gfun.vgfs.learners.svms import MetaClassifier, get_learner
from gfun.vgfs.multilingualGen import MultilingualGen
from gfun.vgfs.textualTransformerGen import TextualTransformerGen
from vgfs.vanillaFun import VanillaFunGen
from vgfs.wceGen import WceGen
from gfun.vgfs.visualTransformerGen import VisualTransformerGen
from gfun.vgfs.vanillaFun import VanillaFunGen
from gfun.vgfs.wceGen import WceGen
class GeneralizedFunnelling:
@ -20,7 +21,8 @@ class GeneralizedFunnelling:
posterior,
wce,
multilingual,
transformer,
textual_transformer,
visual_transformer,
langs,
num_labels,
embed_dir,
@ -31,7 +33,8 @@ class GeneralizedFunnelling:
epochs,
patience,
evaluate_step,
transformer_name,
textual_transformer_name,
visual_transformer_name,
optimc,
device,
load_trained,
@ -44,15 +47,16 @@ class GeneralizedFunnelling:
self.posteriors_vgf = posterior
self.wce_vgf = wce
self.multilingual_vgf = multilingual
self.trasformer_vgf = transformer
self.trasformer_vgf = textual_transformer
self.visual_transformer_vgf = visual_transformer
self.probabilistic = probabilistic
self.num_labels = num_labels
# ------------------------
self.langs = langs
self.embed_dir = embed_dir
self.cached = True
# Transformer VGF params ----------
self.transformer_name = transformer_name
# Textual Transformer VGF params ----------
self.textual_transformer_name = textual_transformer_name
self.epochs = epochs
self.lr_transformer = lr
self.batch_size_transformer = batch_size
@ -61,6 +65,8 @@ class GeneralizedFunnelling:
self.patience = patience
self.evaluate_step = evaluate_step
self.device = device
# Visual Transformer VGF params ----------
self.visual_transformer_name = visual_transformer_name
# Metaclassifier params ------------
self.optimc = optimc
# -------------------
@ -78,7 +84,7 @@ class GeneralizedFunnelling:
self._init()
def _init(self):
print("[Init GeneralizedFunnelling]")
print("\n[Init GeneralizedFunnelling]")
assert not (
self.aggfunc == "mean" and self.probabilistic is False
), "When using averaging aggreagation function probabilistic must be True"
@ -139,20 +145,35 @@ class GeneralizedFunnelling:
if self.trasformer_vgf:
transformer_vgf = TextualTransformerGen(
dataset_name=self.dataset_name,
model_name=self.transformer_name,
model_name=self.textual_transformer_name,
lr=self.lr_transformer,
epochs=self.epochs,
batch_size=self.batch_size_transformer,
max_length=self.max_length,
device=self.device,
print_steps=50,
probabilistic=self.probabilistic,
evaluate_step=self.evaluate_step,
verbose=True,
patience=self.patience,
device=self.device,
)
self.first_tier_learners.append(transformer_vgf)
if self.visual_transformer_vgf:
visual_transformer_vgf = VisualTransformerGen(
dataset_name=self.dataset_name,
model_name="vit",
lr=1e-5, # self.lr_visual_transformer,
epochs=self.epochs,
batch_size=32, # self.batch_size_visual_transformer,
# batch_size_eval=128,
probabilistic=self.probabilistic,
evaluate_step=self.evaluate_step,
patience=self.patience,
device=self.device,
)
self.first_tier_learners.append(visual_transformer_vgf)
if "attn" in self.aggfunc:
attn_stacking = self.aggfunc.split("_")[1]
self.attn_aggregator = AttentionAggregator(
@ -189,15 +210,18 @@ class GeneralizedFunnelling:
vgf.vectorizer = self.vectorizer
def fit(self, lX, lY):
print("[Fitting GeneralizedFunnelling]")
print("\n[Fitting GeneralizedFunnelling]")
if self.load_trained is not None:
print(
"- loaded first tier learners!"
if self.load_meta is False
else "- loaded trained model!"
)
"""
if we are only loading the first tier, we need to
transform the training data to train the meta-classifier
"""
if self.load_first_tier is True and self.load_meta is False:
# TODO: clean up this code here
projections = []
for vgf in self.first_tier_learners:
l_posteriors = vgf.transform(lX)
@ -403,7 +427,7 @@ def get_unique_id(dataset_name, posterior, multilingual, wce, transformer, aggfu
from datetime import datetime
now = datetime.now().strftime("%y%m%d")
model_id = dataset_name
model_id = f"{dataset_name}_"
model_id += "p" if posterior else ""
model_id += "m" if multilingual else ""
model_id += "w" if wce else ""

View File

@ -4,17 +4,17 @@ from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize
from torch.optim import AdamW
from transformers.modeling_outputs import SequenceClassifierOutput
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from transformers.modeling_outputs import ModelOutput
from evaluation.evaluate import evaluate, log_eval
PRINT_ON_EPOCH = 10
PRINT_ON_EPOCH = 1
def _normalize(lX, l2=True):
@ -78,12 +78,12 @@ class TfidfVectorizerMultilingual:
def fit(self, lX, ly=None):
self.langs = sorted(lX.keys())
self.vectorizer = {
l: TfidfVectorizer(**self.kwargs).fit(lX[l]) for l in self.langs
l: TfidfVectorizer(**self.kwargs).fit(lX[l]["text"]) for l in self.langs
}
return self
def transform(self, lX):
return {l: self.vectorizer[l].transform(lX[l]) for l in self.langs}
return {l: self.vectorizer[l].transform(lX[l]["text"]) for l in self.langs}
def fit_transform(self, lX, ly=None):
return self.fit(lX, ly).transform(lX)
@ -123,6 +123,7 @@ class Trainer:
self.print_steps = print_steps
self.experiment_name = experiment_name
self.patience = patience
self.print_eval = evaluate_step
self.earlystopping = EarlyStopping(
patience=patience,
checkpoint_path=checkpoint_path,
@ -144,12 +145,15 @@ class Trainer:
- train batch size: {train_dataloader.batch_size}
- eval batch size: {eval_dataloader.batch_size}
- max len: {train_dataloader.dataset.X.shape[-1]}
- patience: {self.earlystopping.patience}\n"""
- patience: {self.earlystopping.patience}
- evaluate every: {self.evaluate_steps}
- print eval every: {self.print_eval}
- print train steps: {self.print_steps}\n"""
)
for epoch in range(epochs):
self.train_epoch(train_dataloader, epoch)
if (epoch + 1) % self.evaluate_steps == 0:
print_eval = (epoch + 1) % 25 == 0
print_eval = (epoch + 1) % self.print_eval == 0
metric_watcher = self.evaluate(eval_dataloader, print_eval=print_eval)
stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
if stop:
@ -160,7 +164,7 @@ class Trainer:
self.device
)
break
print(f"\n- last swipe on eval set")
print(f"- last swipe on eval set")
self.train_epoch(eval_dataloader, epoch=0)
self.earlystopping.save_model(self.model)
return self.model
@ -170,7 +174,7 @@ class Trainer:
for b_idx, (x, y, lang) in enumerate(dataloader):
self.optimizer.zero_grad()
y_hat = self.model(x.to(self.device))
if isinstance(y_hat, SequenceClassifierOutput):
if isinstance(y_hat, ModelOutput):
loss = self.loss_fn(y_hat.logits, y.to(self.device))
else:
loss = self.loss_fn(y_hat, y.to(self.device))
@ -189,7 +193,7 @@ class Trainer:
for b_idx, (x, y, lang) in enumerate(dataloader):
y_hat = self.model(x.to(self.device))
if isinstance(y_hat, SequenceClassifierOutput):
if isinstance(y_hat, ModelOutput):
loss = self.loss_fn(y_hat.logits, y.to(self.device))
predictions = predict(y_hat.logits, classification_type="multilabel")
else:
@ -272,6 +276,10 @@ class AttentionModule(nn.Module):
self.linear = nn.Linear(embed_dim, out_dim)
self.sigmoid = nn.Sigmoid()
def init_weights(self, mode="mean"):
# TODO: add init function of the attention module: either all weights are positive or set to 1/num_classes
raise NotImplementedError
def __call__(self, X):
out, attn_weights = self.attn(query=X, key=X, value=X)
# out = self.layer_norm(out)

View File

@ -4,9 +4,9 @@ import torch
import numpy as np
from torchtext.vocab import Vectors
from joblib import Parallel, delayed
from vgfs.viewGen import ViewGen
from vgfs.commons import _normalize, XdotM
from vgfs.learners.svms import FeatureSet2Posteriors
from gfun.vgfs.viewGen import ViewGen
from gfun.vgfs.commons import _normalize, XdotM
from gfun.vgfs.learners.svms import FeatureSet2Posteriors
class MultilingualGen(ViewGen):

View File

@ -7,14 +7,14 @@ from collections import defaultdict
import numpy as np
import torch
import transformers
from sklearn.model_selection import train_test_split
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
# from sklearn.model_selection import train_test_split
# from torch.optim import AdamW
from torch.utils.data import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from vgfs.learners.svms import FeatureSet2Posteriors
from vgfs.viewGen import ViewGen
from vgfs.transformerGen import TransformerGen
from vgfs.commons import Trainer, predict
from gfun.vgfs.commons import Trainer
from gfun.vgfs.transformerGen import TransformerGen
from gfun.vgfs.viewGen import ViewGen
transformers.logging.set_verbosity_error()
@ -104,7 +104,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
)
tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
lX, lY, split=0.2, seed=42
lX, lY, split=0.2, seed=42, modality="text"
)
tra_dataloader = self.build_dataloader(
@ -156,6 +156,8 @@ class TextualTransformerGen(ViewGen, TransformerGen):
return self
def transform(self, lX):
# forcing to only text modality
lX = {lang: data["text"] for lang, data in lX.items()}
_embeds = []
l_embeds = defaultdict(list)
@ -196,8 +198,8 @@ class TextualTransformerGen(ViewGen, TransformerGen):
from os import makedirs
from os.path import join
vgf_name = "transformerGen"
_basedir = join("models", "vgfs", "transformer")
vgf_name = "textualTransformerGen"
_basedir = join("models", "vgfs", "textual_transformer")
makedirs(_basedir, exist_ok=True)
_path = join(_basedir, f"{vgf_name}_{model_id}.pkl")
with open(_path, "wb") as f:

View File

@ -1,6 +1,6 @@
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from vgfs.learners.svms import FeatureSet2Posteriors
from gfun.vgfs.learners.svms import FeatureSet2Posteriors
class TransformerGen:
@ -67,16 +67,26 @@ class TransformerGen:
split="train",
shuffle=False,
):
l_tokenized = {lang: processor_fn(data) for lang, data in lX.items()}
self.datasets[split] = torchDataset(l_tokenized, lY, split=split)
return DataLoader(self.datasets[split], batch_size=batch_size, shuffle=shuffle)
l_processed = {lang: processor_fn(lX[lang]) for lang in lX.keys()}
self.datasets[split] = torchDataset(l_processed, lY, split=split)
return DataLoader(
self.datasets[split],
batch_size=batch_size,
shuffle=shuffle,
# collate_fn=processor_fn,
)
def get_train_val_data(self, lX, lY, split=0.2, seed=42):
def get_train_val_data(self, lX, lY, split=0.2, seed=42, modality="text"):
assert modality in ["text", "image"], "modality must be either text or image"
tr_lX, tr_lY, val_lX, val_lY = {}, {}, {}, {}
for lang in lX.keys():
tr_X, val_X, tr_Y, val_Y = train_test_split(
lX[lang], lY[lang], test_size=split, random_state=seed, shuffle=False
lX[lang][modality],
lY[lang],
test_size=split,
random_state=seed,
shuffle=False,
)
tr_lX[lang] = tr_X
tr_lY[lang] = tr_Y

View File

@ -1,6 +1,6 @@
from vgfs.viewGen import ViewGen
from vgfs.learners.svms import NaivePolylingualClassifier
from vgfs.commons import _normalize
from gfun.vgfs.viewGen import ViewGen
from gfun.vgfs.learners.svms import NaivePolylingualClassifier
from gfun.vgfs.commons import _normalize
class VanillaFunGen(ViewGen):

View File

@ -1,16 +1,15 @@
import sys, os
sys.path.append(os.getcwd())
from collections import defaultdict
import numpy as np
import torch
import transformers
from gfun.vgfs.viewGen import ViewGen
from transformers import AutoImageProcessor
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
from gfun.vgfs.commons import Trainer, predict
from PIL import Image
from torch.utils.data import Dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification
from gfun.vgfs.commons import Trainer
from gfun.vgfs.transformerGen import TransformerGen
from transformers import AutoModelForImageClassification
from gfun.vgfs.viewGen import ViewGen
transformers.logging.set_verbosity_error()
@ -26,6 +25,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
batch_size_eval=128,
evaluate_step=10,
device="cpu",
probabilistic=False,
patience=5,
):
super().__init__(
@ -38,6 +38,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
device=device,
evaluate_step=evaluate_step,
patience=patience,
probabilistic=probabilistic,
)
self.fitted = False
print(
@ -52,44 +53,28 @@ class VisualTransformerGen(ViewGen, TransformerGen):
def init_model(self, model_name, num_labels):
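# output_hidden_states=True exposes the ViT hidden states so that transform() can extract embeddings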
model = AutoModelForImageClassification.from_pretrained(
model_name, num_labels=num_labels
model_name, num_labels=num_labels, output_hidden_states=True
)
image_processor = AutoImageProcessor.from_pretrained(model_name)
transforms = self.init_preprocessor(image_processor)
return model, image_processor, transforms
def init_preprocessor(self, image_processor):
normalize = Normalize(
mean=image_processor.image_mean, std=image_processor.image_std
)
size = (
image_processor.size["shortest_edge"]
if "shortest_edge" in image_processor.size
else (image_processor.size["height"], image_processor.size["width"])
)
# these are the transformations that we are applying to the images
transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])
return transforms
def preprocess(self, images, transforms):
processed = transforms(img.convert("RGB") for img in images)
return processed
return model, image_processor
def process_all(self, X):
# TODO: every element in X is a tuple (doc_id, clean_text, text, Pil.Image), so we're taking just the last element for processing
processed = torch.stack([self.transforms(img[-1]) for img in X])
return processed
# TODO: should be moved as a collate_fn to avoid this overhead
processed = self.image_preprocessor(
[Image.open(img).convert("RGB") for img in X], return_tensors="pt"
)
return processed["pixel_values"]
def fit(self, lX, lY):
print("- fitting Visual Transformer View Generating Function")
_l = list(lX.keys())[0]
self.num_labels = lY[_l].shape[-1]
self.model, self.image_preprocessor, self.transforms = self.init_model(
self.model, self.image_preprocessor = self.init_model(
self._validate_model_name(self.model_name), num_labels=self.num_labels
)
tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
lX, lY, split=0.2, seed=42
lX, lY, split=0.2, seed=42, modality="image"
)
tra_dataloader = self.build_dataloader(
@ -135,17 +120,64 @@ class VisualTransformerGen(ViewGen, TransformerGen):
if self.probabilistic:
self.feature2posterior_projector.fit(self.transform(lX), lY)
self.fitted = True
return self
def transform(self, lX):
raise NotImplementedError
# forcing to only image modality
lX = {lang: data["image"] for lang, data in lX.items()}
_embeds = []
l_embeds = defaultdict(list)
dataloader = self.build_dataloader(
lX,
lY=None,
processor_fn=self.process_all,
torchDataset=MultimodalDatasetTorch,
batch_size=self.batch_size_eval,
split="whole",
shuffle=False,
)
self.model.eval()
with torch.no_grad():
for input_ids, lang in dataloader:
input_ids = input_ids.to(self.device)
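# take the first ([CLS]) token of the last hidden layer as the image embedding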
out = self.model(input_ids).hidden_states[-1]
batch_embeddings = out[:, 0, :].cpu().numpy()
_embeds.append((batch_embeddings, lang))
for embed, lang in _embeds:
for sample_embed, sample_lang in zip(embed, lang):
l_embeds[sample_lang].append(sample_embed)
if self.probabilistic and self.fitted:
l_embeds = self.feature2posterior_projector.transform(l_embeds)
elif not self.probabilistic and self.fitted:
l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}
return l_embeds
def fit_transform(self, lX, lY):
raise NotImplementedError
return self.fit(lX, lY).transform(lX)
def save_vgf(self, model_id):
raise NotImplementedError
import pickle
from os import makedirs
from os.path import join
def save_vgf(self, model_id):
raise NotImplementedError
vgf_name = "visualTransformerGen"
_basedir = join("models", "vgfs", "visual_transformer")
makedirs(_basedir, exist_ok=True)
_path = join(_basedir, f"{vgf_name}_{model_id}.pkl")
with open(_path, "wb") as f:
pickle.dump(self, f)
return self
def __str__(self):
str = f"[Visual Transformer VGF (v)]\n- model_name: {self.model_name}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
return str
class MultimodalDatasetTorch(Dataset):
@ -169,7 +201,6 @@ class MultimodalDatasetTorch(Dataset):
],
[],
)
# print(f"- lX has shape: {self.X.shape}\n- lY has shape: {self.Y.shape}")
def __len__(self):
return len(self.X)
@ -182,15 +213,28 @@ class MultimodalDatasetTorch(Dataset):
if __name__ == "__main__":
from os.path import expanduser
from dataManager.multiNewsDataset import MultiNewsDataset
_dataset_path_hardcoded = "~/datasets/MultiNews/20110730/"
from dataManager.gFunDataset import gFunDataset
dataset = MultiNewsDataset(expanduser(_dataset_path_hardcoded), debug=True)
lXtr, lYtr = dataset.training()
GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
dataset = gFunDataset(
dataset_dir=GLAMI_DATAPATH,
is_textual=True,
is_visual=True,
is_multilabel=False,
nrows=50,
)
vg = VisualTransformerGen(
model_name="vit", device="cuda", epochs=1000, evaluate_step=10, patience=100
dataset_name=dataset.dataset_name,
model_name="vit",
device="cuda",
epochs=5,
evaluate_step=10,
patience=10,
probabilistic=True,
)
lX, lY = dataset.training()
vg.fit(lX, lY)
out = vg.transform(lX)
exit(0)

View File

@ -1,7 +1,7 @@
import numpy as np
from joblib import Parallel, delayed
from vgfs.commons import XdotM, _normalize
from vgfs.viewGen import ViewGen
from gfun.vgfs.commons import XdotM, _normalize
from gfun.vgfs.viewGen import ViewGen
class WceGen(ViewGen):

52
main.py
View File

@ -7,6 +7,7 @@ from dataManager.amazonDataset import AmazonDataset
from dataManager.multilingualDatset import MultilingualDataset
from dataManager.multiNewsDataset import MultiNewsDataset
from dataManager.glamiDataset import GlamiDataset
from dataManager.gFunDataset import gFunDataset
from evaluation.evaluate import evaluate, log_eval
from gfun.generalizedFunnelling import GeneralizedFunnelling
@ -44,11 +45,15 @@ def get_dataset(datasetname):
GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
if datasetname == "multinews":
# TODO: convert to gFunDataset
raise NotImplementedError
dataset = MultiNewsDataset(
expanduser(MULTINEWS_DATAPATH),
excluded_langs=["ar", "pe", "pl", "tr", "ua"],
)
elif datasetname == "amazon":
# TODO: convert to gFunDataset
raise NotImplementedError
dataset = AmazonDataset(
domains=args.domains,
nrows=args.nrows,
@ -56,12 +61,21 @@ def get_dataset(datasetname):
max_labels=args.max_labels,
)
elif datasetname == "rcv1-2":
dataset = MultilingualDataset(dataset_name="rcv1-2").load(RCV_DATAPATH)
if args.nrows is not None:
dataset.reduce_data(langs=["en", "it", "fr"], maxn=args.nrows)
dataset = gFunDataset(
dataset_dir=RCV_DATAPATH,
is_textual=True,
is_visual=False,
is_multilabel=True,
nrows=args.nrows,
)
elif datasetname == "glami":
dataset = GlamiDataset(dataset_dir=GLAMI_DATAPATH, nrows=args.nrows)
dataset.build_dataset()
dataset = gFunDataset(
dataset_dir=GLAMI_DATAPATH,
is_textual=True,
is_visual=True,
is_multilabel=False,
nrows=args.nrows,
)
else:
raise NotImplementedError
return dataset
@ -73,6 +87,7 @@ def main(args):
isinstance(dataset, MultilingualDataset)
or isinstance(dataset, MultiNewsDataset)
or isinstance(dataset, GlamiDataset)
or isinstance(dataset, gFunDataset)
):
lX, lY = dataset.training()
lX_te, lY_te = dataset.test()
@ -89,7 +104,8 @@ def main(args):
args.wce,
args.multilingual,
args.multilingual,
args.transformer,
args.textual_transformer,
args.visual_transformer,
]
), "At least one of VGF must be True"
@ -106,8 +122,8 @@ def main(args):
# WCE VGF params ----------------------
wce=args.wce,
# Transformer VGF params --------------
transformer=args.transformer,
transformer_name=args.transformer_name,
textual_transformer=args.textual_transformer,
textual_transformer_name=args.transformer_name,
batch_size=args.batch_size,
epochs=args.epochs,
lr=args.lr,
@ -115,6 +131,15 @@ def main(args):
patience=args.patience,
evaluate_step=args.evaluate_step,
device="cuda",
# Visual Transformer VGF params --------------
visual_transformer=args.visual_transformer,
visual_transformer_name=args.visual_transformer_name,
# batch_size=args.batch_size,
# epochs=args.epochs,
# lr=args.lr,
# patience=args.patience,
# evaluate_step=args.evaluate_step,
# device="cuda",
# General params ----------------------
probabilistic=args.features,
aggfunc=args.aggfunc,
@ -152,7 +177,7 @@ if __name__ == "__main__":
parser.add_argument("--meta", action="store_true")
parser.add_argument("--nosave", action="store_true")
# Dataset parameters -------------------
parser.add_argument("-d", "--dataset", type=str, default="multinews")
parser.add_argument("-d", "--dataset", type=str, default="rcv1-2")
parser.add_argument("--domains", type=str, default="all")
parser.add_argument("--nrows", type=int, default=None)
parser.add_argument("--min_count", type=int, default=10)
@ -161,7 +186,8 @@ if __name__ == "__main__":
parser.add_argument("-p", "--posteriors", action="store_true")
parser.add_argument("-m", "--multilingual", action="store_true")
parser.add_argument("-w", "--wce", action="store_true")
parser.add_argument("-t", "--transformer", action="store_true")
parser.add_argument("-t", "--textual_transformer", action="store_true")
parser.add_argument("-v", "--visual_transformer", action="store_true")
parser.add_argument("--n_jobs", type=int, default=-1)
parser.add_argument("--optimc", action="store_true")
parser.add_argument("--features", action="store_false")
@ -169,11 +195,13 @@ if __name__ == "__main__":
# transformer parameters ---------------
parser.add_argument("--transformer_name", type=str, default="mbert")
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--epochs", type=int, default=1000)
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--lr", type=float, default=1e-5)
parser.add_argument("--max_length", type=int, default=512)
parser.add_argument("--max_length", type=int, default=128)
parser.add_argument("--patience", type=int, default=5)
parser.add_argument("--evaluate_step", type=int, default=10)
# Visual Transformer parameters --------------
parser.add_argument("--visual_transformer_name", type=str, default="vit")
args = parser.parse_args()
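With the new flags, a GLAMI run combining the posterior, textual-transformer, and visual-transformer VGFs could be launched with something like the following (hypothetical invocation built from the arguments above):

python main.py -d glami -p -t -v --nrows 10000 --epochs 100 --evaluate_step 10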