diff --git a/.gitignore b/.gitignore index b4f57ee..25278f6 100644 --- a/.gitignore +++ b/.gitignore @@ -179,4 +179,5 @@ out/* amazon_cateogories.bu.txt models/* scripts/ -logger/* \ No newline at end of file +logger/* +explore_data.ipynb \ No newline at end of file diff --git a/dataManager/gFunDataset.py b/dataManager/gFunDataset.py new file mode 100644 index 0000000..5532c6e --- /dev/null +++ b/dataManager/gFunDataset.py @@ -0,0 +1,223 @@ +from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer +from dataManager.glamiDataset import get_dataframe +from dataManager.multilingualDatset import MultilingualDataset + + +class gFunDataset: + def __init__( + self, + dataset_dir, + is_textual, + is_visual, + is_multilabel, + labels=None, + nrows=None, + data_langs=None, + ): + self.dataset_dir = dataset_dir + self.data_langs = data_langs + self.is_textual = is_textual + self.is_visual = is_visual + self.is_multilabel = is_multilabel + self.labels = labels + self.nrows = nrows + self.dataset = {} + self.load_dataset() + + def get_label_binarizer(self, labels): + if self.dataset_name in ["rcv1-2", "jrc"]: + mlb = "Labels are already binarized for rcv1-2 dataset" + elif self.is_multilabel: + mlb = MultiLabelBinarizer() + mlb.fit([labels]) + else: + mlb = LabelBinarizer() + mlb.fit(labels) + return mlb + + def load_dataset(self): + if "glami" in self.dataset_dir.lower(): + print(f"- Loading GLAMI dataset from {self.dataset_dir}") + self.dataset_name = "glami" + self.dataset, self.labels, self.data_langs = self._load_glami( + self.dataset_dir, self.nrows + ) + self.mlb = self.get_label_binarizer(self.labels) + + elif "rcv" in self.dataset_dir.lower(): + print(f"- Loading RCV1-2 dataset from {self.dataset_dir}") + self.dataset_name = "rcv1-2" + self.dataset, self.labels, self.data_langs = self._load_multilingual( + self.dataset_name, self.dataset_dir, self.nrows + ) + self.mlb = self.get_label_binarizer(self.labels) + + elif "jrc" in self.dataset_dir.lower(): + print(f"- Loading JRC dataset from {self.dataset_dir}") + self.dataset_name = "jrc" + self.dataset, self.labels, self.data_langs = self._load_multilingual( + self.dataset_name, self.dataset_dir, self.nrows + ) + self.mlb = self.get_label_binarizer(self.labels) + + self.show_dimension() + + return + + def show_dimension(self): + print(f"\n[Dataset: {self.dataset_name.upper()}]") + for lang, data in self.dataset.items(): + print( + f"-- Lang: {lang} - train docs: {len(data['train']['text'])} - test docs: {len(data['test']['text'])}" + ) + if self.dataset_name in ["rcv1-2", "jrc"]: + print(f"-- Labels: {self.labels}") + else: + print(f"-- Labels: {len(self.labels)}") + + def _load_multilingual(self, dataset_name, dataset_dir, nrows): + old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir) + if nrows is not None: + old_dataset.reduce_data(langs=["en", "it", "fr"], maxn=nrows) + labels = old_dataset.num_labels() + data_langs = old_dataset.langs() + + def _format_multilingual(data): + text = data[0] + image = None + labels = data[1] + return {"text": text, "image": image, "label": labels} + + dataset = { + k: {"train": _format_multilingual(v[0]), "test": _format_multilingual(v[1])} + for k, v in old_dataset.multiling_dataset.items() + } + return dataset, labels, data_langs + + def _load_glami(self, dataset_dir, nrows): + def _balanced_sample(data, n, remainder=0): + import pandas as pd + + langs = sorted(data.geo.unique().tolist()) + dict_n = {lang: n for lang in langs} + dict_n[langs[0]] += remainder + + sampled 
= [] + for lang in langs: + sampled.append(data[data.geo == lang].sample(n=dict_n[lang])) + + return pd.concat(sampled, axis=0) + + # TODO: make this sampling deterministic / dependent on the seed + lang_nrows = ( + nrows // 13 if self.data_langs is None else nrows // len(self.data_langs) + ) # GLAMI 1-M has 13 languages + remainder = ( + nrows % 13 if self.data_langs is None else nrows % len(self.data_langs) + ) + + train_split = get_dataframe("train", dataset_dir=dataset_dir) + train_split = _balanced_sample(train_split, lang_nrows, remainder=remainder) + + if self.data_langs is None: + data_langs = sorted(train_split.geo.unique().tolist()) + # TODO: if data_langs is NOT None we still need to filter the dataframe by those languages + if self.labels is None: + labels = train_split.category_name.unique().tolist() + + # TODO: at the moment, test data must contain the same languages as the training data + test_split = get_dataframe("test", dataset_dir=dataset_dir) + # TODO: at the moment we're using a 1:1 train-test ratio + test_split = _balanced_sample(test_split, lang_nrows, remainder=remainder) + + gb_train = train_split.groupby("geo") + gb_test = test_split.groupby("geo") + + def _format_glami(data_df): + text = (data_df.name + " " + data_df.description).tolist() + image = data_df.image_file.tolist() + labels = data_df.category_name.tolist() + return {"text": text, "image": image, "label": labels} + + dataset = { + lang: { + "train": _format_glami(data_tr), + "test": _format_glami(gb_test.get_group(lang)), + } + for lang, data_tr in gb_train + if lang in data_langs + } + + return dataset, labels, data_langs + + def binarize_labels(self, labels): + if self.dataset_name in ["rcv1-2", "jrc"]: + # labels are already binarized for rcv1-2 dataset + return labels + if hasattr(self, "mlb"): + return self.mlb.transform(labels) + else: + raise AttributeError("Label binarizer not found") + + def training(self): + lXtr = {} + lYtr = {} + for lang in self.data_langs: + text = self.dataset[lang]["train"]["text"] if self.is_textual else None + img = self.dataset[lang]["train"]["image"] if self.is_visual else None + labels = self.dataset[lang]["train"]["label"] + + lXtr[lang] = {"text": text, "image": img} + lYtr[lang] = self.binarize_labels(labels) + + return lXtr, lYtr + + def test(self): + lXte = {} + lYte = {} + for lang in self.data_langs: + text = self.dataset[lang]["test"]["text"] if self.is_textual else None + img = self.dataset[lang]["test"]["image"] if self.is_visual else None + labels = self.dataset[lang]["test"]["label"] + + lXte[lang] = {"text": text, "image": img} + lYte[lang] = self.binarize_labels(labels) + + return lXte, lYte + + def langs(self): + return self.data_langs + + def num_labels(self): + if self.dataset_name not in ["rcv1-2", "jrc"]: + return len(self.labels) + else: + return self.labels + + +if __name__ == "__main__": + import os + + GLAMI_DATAPATH = os.path.expanduser("~/datasets/GLAMI-1M-dataset") + RCV_DATAPATH = os.path.expanduser( + "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle" + ) + JRC_DATAPATH = os.path.expanduser( + "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle" + ) + + print("Hello gFunDataset") + dataset = gFunDataset( + # dataset_dir=GLAMI_DATAPATH, + # dataset_dir=RCV_DATAPATH, + dataset_dir=JRC_DATAPATH, + data_langs=None, + is_textual=True, + is_visual=True, + is_multilabel=False, + labels=None, + nrows=13, + ) + lXtr, lYtr = dataset.training() + lXte, lYte = dataset.test() + exit(0)
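For reference, a minimal usage sketch of the gFunDataset interface introduced above (the dataset path and nrows value are illustrative, not taken from this patch; the per-language layout follows the training()/test() methods):

import os
from dataManager.gFunDataset import gFunDataset

# Illustrative GLAMI-1M location; nrows is split evenly across the 13 GLAMI
# languages by _balanced_sample (any remainder goes to the first language).
dataset = gFunDataset(
    dataset_dir=os.path.expanduser("~/datasets/GLAMI-1M-dataset"),
    is_textual=True,
    is_visual=True,
    is_multilabel=False,
    nrows=130,  # 10 documents per language
)

lXtr, lYtr = dataset.training()
# lXtr[lang] -> {"text": [str, ...], "image": [image_file_path, ...]}
# lYtr[lang] -> binarized label matrix with one row per document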
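A hedged illustration of the two label-binarization paths used by get_label_binarizer() and binarize_labels() (the category names are made up; rcv1-2 and jrc labels are returned as-is because they are already binarized):

from sklearn.preprocessing import LabelBinarizer, MultiLabelBinarizer

# Multilabel case (is_multilabel=True): fit([labels]) registers every known label
# as a class; transform() then expects one label *set* per document.
mlb = MultiLabelBinarizer()
mlb.fit([["bags", "dresses", "shoes"]])
mlb.transform([["dresses"], ["bags", "shoes"]])
# -> [[0, 1, 0],
#     [1, 0, 1]]   (columns follow mlb.classes_, i.e. sorted label names)

# Single-label case (e.g. GLAMI): one category per document, one-hot rows.
lb = LabelBinarizer()
lb.fit(["bags", "dresses", "shoes"])
lb.transform(["shoes", "bags"])
# -> [[0, 0, 1],
#     [1, 0, 0]]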
diff --git a/dataManager/glamiDataset.py b/dataManager/glamiDataset.py index 0d3ad50..46441b0 100644 --- a/dataManager/glamiDataset.py +++ b/dataManager/glamiDataset.py @@ -100,6 +100,26 @@ class GlamiDataset: self.nrows = nrows self.multilingual_dataset = {} + """ + self.multilingual_multimodal_dataset = { + lang: { + "text": {txt_data}, + "image": {img_data}, + } + } + TODO: if we decide to do this, we need to change both the + training (e.g. the vectorizer should access "text") and also the + multilingual unimodal dataset (to include the field "text" only). + BUT this will be a pain when we split/shuffle the datasets. + I think it is better to have something like this: + + self.ml_mm_dataset = { + "lang": (txt_data, img_data) + } + + but then the unimodal dataset should also have a "lang": (txt_data, _) value + """ + def num_labels(self): return len(self.labels) @@ -143,7 +163,7 @@ class GlamiDataset: def training(self): # TODO: tolist() or ??? lXtr = { - lang: (df.name + " " + df.description).tolist() + lang: ((df.name + " " + df.description).tolist(), df.image_file.tolist()) for lang, (df, _) in self.multilingual_dataset.items() } lYtr = { @@ -154,7 +174,7 @@ class GlamiDataset: def test(self): lXte = { - lang: (df.name + " " + df.description).tolist() + lang: ((df.name + " " + df.description).tolist(), df.image_file.tolist()) for lang, (_, df) in self.multilingual_dataset.items() } lYte = { diff --git a/dataManager/multiNewsDataset.py b/dataManager/multiNewsDataset.py index 9693937..87fba69 100644 --- a/dataManager/multiNewsDataset.py +++ b/dataManager/multiNewsDataset.py @@ -86,7 +86,7 @@ class MultiNewsDataset: lXtr = {} lYtr = {} for lang, data in self.lang_multiModalDataset.items(): - _data = [clean_text for _, clean_text, _, _ in data.data] + _data = [(clean_text, img) for _, clean_text, _, img in data.data] lXtr[lang] = _data lYtr = { lang: self.label_binarizer.transform(data.labels) diff --git a/dataManager/multilingualDatset.py b/dataManager/multilingualDatset.py index 7fd53f0..6ec2d78 100644 --- a/dataManager/multilingualDatset.py +++ b/dataManager/multilingualDatset.py @@ -64,7 +64,7 @@ class MultilingualDataset: def __init__(self, dataset_name): self.dataset_name = dataset_name self.multiling_dataset = {} - print(f"[Init Multilingual Dataset: {self.dataset_name}]") + # print(f"[Init Multilingual Dataset: {self.dataset_name}]") def add(self, lang, Xtr, Ytr, Xte, Yte, tr_ids=None, te_ids=None): self.multiling_dataset[lang] = ((Xtr, Ytr, tr_ids), (Xte, Yte, te_ids)) @@ -171,7 +171,7 @@ class MultilingualDataset: else: langs = sorted(self.multiling_dataset.keys()) return langs - + def num_labels(self): return self.num_categories() diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py index 36363ca..28c1649 100644 --- a/evaluation/evaluate.py +++ b/evaluation/evaluate.py @@ -7,9 +7,8 @@ def evaluation_metrics(y, y_): if len(y.shape) == len(y_.shape) == 1 and len(np.unique(y)) > 2: # single-label raise NotImplementedError() # return f1_score(y,y_,average='macro'), f1_score(y,y_,average='micro') else: # the metrics I implemented assume multiclass multilabel classification as binary classifiers - # return macroF1(y, y_), microF1(y, y_), macroK(y, y_), microK(y, y_), macroP(y, y_), microP(y, y_), macroR(y, y_), microR(y, y_) - # return macroF1(y, y_), microF1(y, y_), macroAcc(y, y_), microAcc(y, y_), macroP(y, y_), microP(y, y_), macroR(y, y_), microR(y, y_), macroAcc(y, y_) return macroF1(y, y_), microF1(y, y_), macroK(y, y_), microK(y, y_) + # return macroF1(y, y_), microF1(y, y_), macroK(y, y_), macroAcc(y, y_) def
evaluate(ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1): diff --git a/evaluation/metrics.py b/evaluation/metrics.py index 19424e6..80cee37 100644 --- a/evaluation/metrics.py +++ b/evaluation/metrics.py @@ -235,3 +235,7 @@ def macroK(true_labels, predicted_labels): # true_labels and predicted_labels are two matrices in sklearn.preprocessing.MultiLabelBinarizer format def microK(true_labels, predicted_labels): return micro_average(true_labels, predicted_labels, K) + + +def macroAcc(true_labels, predicted_labels): + return macro_average(true_labels, predicted_labels, accuracy) diff --git a/gfun/generalizedFunnelling.py b/gfun/generalizedFunnelling.py index 4f2784d..3600999 100644 --- a/gfun/generalizedFunnelling.py +++ b/gfun/generalizedFunnelling.py @@ -1,17 +1,18 @@ import os import sys -sys.path.append(os.path.join(os.getcwd(), "gfun")) +# sys.path.append(os.path.join(os.getcwd(), "gfun")) import pickle import numpy as np -from vgfs.commons import TfidfVectorizerMultilingual, AttentionAggregator -from vgfs.learners.svms import MetaClassifier, get_learner -from vgfs.multilingualGen import MultilingualGen +from gfun.vgfs.commons import TfidfVectorizerMultilingual, AttentionAggregator +from gfun.vgfs.learners.svms import MetaClassifier, get_learner +from gfun.vgfs.multilingualGen import MultilingualGen from gfun.vgfs.textualTransformerGen import TextualTransformerGen -from vgfs.vanillaFun import VanillaFunGen -from vgfs.wceGen import WceGen +from gfun.vgfs.visualTransformerGen import VisualTransformerGen +from gfun.vgfs.vanillaFun import VanillaFunGen +from gfun.vgfs.wceGen import WceGen class GeneralizedFunnelling: @@ -20,7 +21,8 @@ class GeneralizedFunnelling: posterior, wce, multilingual, - transformer, + textual_transformer, + visual_transformer, langs, num_labels, embed_dir, @@ -31,7 +33,8 @@ class GeneralizedFunnelling: epochs, patience, evaluate_step, - transformer_name, + textual_transformer_name, + visual_transformer_name, optimc, device, load_trained, @@ -44,15 +47,16 @@ class GeneralizedFunnelling: self.posteriors_vgf = posterior self.wce_vgf = wce self.multilingual_vgf = multilingual - self.trasformer_vgf = transformer + self.trasformer_vgf = textual_transformer + self.visual_transformer_vgf = visual_transformer self.probabilistic = probabilistic self.num_labels = num_labels # ------------------------ self.langs = langs self.embed_dir = embed_dir self.cached = True - # Transformer VGF params ---------- - self.transformer_name = transformer_name + # Textual Transformer VGF params ---------- + self.textual_transformer_name = textual_transformer_name self.epochs = epochs self.lr_transformer = lr self.batch_size_transformer = batch_size @@ -61,6 +65,8 @@ class GeneralizedFunnelling: self.patience = patience self.evaluate_step = evaluate_step self.device = device + # Visual Transformer VGF params ---------- + self.visual_transformer_name = visual_transformer_name # Metaclassifier params ------------ self.optimc = optimc # ------------------- @@ -78,7 +84,7 @@ class GeneralizedFunnelling: self._init() def _init(self): - print("[Init GeneralizedFunnelling]") + print("\n[Init GeneralizedFunnelling]") assert not ( self.aggfunc == "mean" and self.probabilistic is False ), "When using averaging aggreagation function probabilistic must be True" @@ -139,20 +145,35 @@ class GeneralizedFunnelling: if self.trasformer_vgf: transformer_vgf = TextualTransformerGen( dataset_name=self.dataset_name, - model_name=self.transformer_name, + model_name=self.textual_transformer_name,
lr=self.lr_transformer, epochs=self.epochs, batch_size=self.batch_size_transformer, max_length=self.max_length, - device=self.device, print_steps=50, probabilistic=self.probabilistic, evaluate_step=self.evaluate_step, verbose=True, patience=self.patience, + device=self.device, ) self.first_tier_learners.append(transformer_vgf) + if self.visual_transformer_vgf: + visual_trasformer_vgf = VisualTransformerGen( + dataset_name=self.dataset_name, + model_name="vit", + lr=1e-5, # self.lr_visual_transformer, + epochs=self.epochs, + batch_size=32, # self.batch_size_visual_transformer, + # batch_size_eval=128, + probabilistic=self.probabilistic, + evaluate_step=self.evaluate_step, + patience=self.patience, + device=self.device, + ) + self.first_tier_learners.append(visual_trasformer_vgf) + if "attn" in self.aggfunc: attn_stacking = self.aggfunc.split("_")[1] self.attn_aggregator = AttentionAggregator( @@ -189,15 +210,18 @@ class GeneralizedFunnelling: vgf.vectorizer = self.vectorizer def fit(self, lX, lY): - print("[Fitting GeneralizedFunnelling]") + print("\n[Fitting GeneralizedFunnelling]") if self.load_trained is not None: print( "- loaded first tier learners!" if self.load_meta is False else "- loaded trained model!" ) + """ + if we are only loading the first tier, we need to + transform the training data to train the meta-classifier + """ if self.load_first_tier is True and self.load_meta is False: - # TODO: clean up this code here projections = [] for vgf in self.first_tier_learners: l_posteriors = vgf.transform(lX) @@ -403,7 +427,7 @@ def get_unique_id(dataset_name, posterior, multilingual, wce, transformer, aggfu from datetime import datetime now = datetime.now().strftime("%y%m%d") - model_id = dataset_name + model_id = f"{dataset_name}_" model_id += "p" if posterior else "" model_id += "m" if multilingual else "" model_id += "w" if wce else "" diff --git a/gfun/vgfs/commons.py b/gfun/vgfs/commons.py index 3231fcd..92a481a 100644 --- a/gfun/vgfs/commons.py +++ b/gfun/vgfs/commons.py @@ -4,17 +4,17 @@ from collections import defaultdict import numpy as np import torch import torch.nn as nn -from torch.utils.data import DataLoader, Dataset from sklearn.decomposition import TruncatedSVD from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.model_selection import train_test_split from sklearn.preprocessing import normalize from torch.optim import AdamW -from transformers.modeling_outputs import SequenceClassifierOutput -from sklearn.model_selection import train_test_split +from torch.utils.data import DataLoader, Dataset +from transformers.modeling_outputs import ModelOutput from evaluation.evaluate import evaluate, log_eval -PRINT_ON_EPOCH = 10 +PRINT_ON_EPOCH = 1 def _normalize(lX, l2=True): @@ -78,12 +78,12 @@ class TfidfVectorizerMultilingual: def fit(self, lX, ly=None): self.langs = sorted(lX.keys()) self.vectorizer = { - l: TfidfVectorizer(**self.kwargs).fit(lX[l]) for l in self.langs + l: TfidfVectorizer(**self.kwargs).fit(lX[l]["text"]) for l in self.langs } return self def transform(self, lX): - return {l: self.vectorizer[l].transform(lX[l]) for l in self.langs} + return {l: self.vectorizer[l].transform(lX[l]["text"]) for l in self.langs} def fit_transform(self, lX, ly=None): return self.fit(lX, ly).transform(lX) @@ -123,6 +123,7 @@ class Trainer: self.print_steps = print_steps self.experiment_name = experiment_name self.patience = patience + self.print_eval = evaluate_step self.earlystopping = EarlyStopping( patience=patience, checkpoint_path=checkpoint_path, @@ 
-144,12 +145,15 @@ class Trainer: - train batch size: {train_dataloader.batch_size} - eval batch size: {eval_dataloader.batch_size} - max len: {train_dataloader.dataset.X.shape[-1]} - - patience: {self.earlystopping.patience}\n""" + - patience: {self.earlystopping.patience} + - evaluate every: {self.evaluate_steps} + - print eval every: {self.print_eval} + - print train steps: {self.print_steps}\n""" ) for epoch in range(epochs): self.train_epoch(train_dataloader, epoch) if (epoch + 1) % self.evaluate_steps == 0: - print_eval = (epoch + 1) % 25 == 0 + print_eval = (epoch + 1) % self.print_eval == 0 metric_watcher = self.evaluate(eval_dataloader, print_eval=print_eval) stop = self.earlystopping(metric_watcher, self.model, epoch + 1) if stop: @@ -160,7 +164,7 @@ class Trainer: self.device ) break - print(f"\n- last swipe on eval set") + print(f"- last swipe on eval set") self.train_epoch(eval_dataloader, epoch=0) self.earlystopping.save_model(self.model) return self.model @@ -170,7 +174,7 @@ class Trainer: for b_idx, (x, y, lang) in enumerate(dataloader): self.optimizer.zero_grad() y_hat = self.model(x.to(self.device)) - if isinstance(y_hat, SequenceClassifierOutput): + if isinstance(y_hat, ModelOutput): loss = self.loss_fn(y_hat.logits, y.to(self.device)) else: loss = self.loss_fn(y_hat, y.to(self.device)) @@ -189,7 +193,7 @@ class Trainer: for b_idx, (x, y, lang) in enumerate(dataloader): y_hat = self.model(x.to(self.device)) - if isinstance(y_hat, SequenceClassifierOutput): + if isinstance(y_hat, ModelOutput): loss = self.loss_fn(y_hat.logits, y.to(self.device)) predictions = predict(y_hat.logits, classification_type="multilabel") else: @@ -272,6 +276,10 @@ class AttentionModule(nn.Module): self.linear = nn.Linear(embed_dim, out_dim) self.sigmoid = nn.Sigmoid() + def init_weights(self, mode="mean"): + # TODO: add init function of the attention module: either all weights are positive or set to 1/num_classes + raise NotImplementedError + def __call__(self, X): out, attn_weights = self.attn(query=X, key=X, value=X) # out = self.layer_norm(out) diff --git a/gfun/vgfs/multilingualGen.py b/gfun/vgfs/multilingualGen.py index e4a8386..231fb18 100644 --- a/gfun/vgfs/multilingualGen.py +++ b/gfun/vgfs/multilingualGen.py @@ -4,9 +4,9 @@ import torch import numpy as np from torchtext.vocab import Vectors from joblib import Parallel, delayed -from vgfs.viewGen import ViewGen -from vgfs.commons import _normalize, XdotM -from vgfs.learners.svms import FeatureSet2Posteriors +from gfun.vgfs.viewGen import ViewGen +from gfun.vgfs.commons import _normalize, XdotM +from gfun.vgfs.learners.svms import FeatureSet2Posteriors class MultilingualGen(ViewGen): diff --git a/gfun/vgfs/textualTransformerGen.py b/gfun/vgfs/textualTransformerGen.py index 9d86b40..8a525c6 100644 --- a/gfun/vgfs/textualTransformerGen.py +++ b/gfun/vgfs/textualTransformerGen.py @@ -7,14 +7,14 @@ from collections import defaultdict import numpy as np import torch import transformers -from sklearn.model_selection import train_test_split -from torch.optim import AdamW -from torch.utils.data import DataLoader, Dataset + +# from sklearn.model_selection import train_test_split +# from torch.optim import AdamW +from torch.utils.data import Dataset from transformers import AutoModelForSequenceClassification, AutoTokenizer -from vgfs.learners.svms import FeatureSet2Posteriors -from vgfs.viewGen import ViewGen -from vgfs.transformerGen import TransformerGen -from vgfs.commons import Trainer, predict +from gfun.vgfs.commons import Trainer +from 
gfun.vgfs.transformerGen import TransformerGen +from gfun.vgfs.viewGen import ViewGen transformers.logging.set_verbosity_error() @@ -104,7 +104,7 @@ class TextualTransformerGen(ViewGen, TransformerGen): ) tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data( - lX, lY, split=0.2, seed=42 + lX, lY, split=0.2, seed=42, modality="text" ) tra_dataloader = self.build_dataloader( @@ -156,6 +156,8 @@ class TextualTransformerGen(ViewGen, TransformerGen): return self def transform(self, lX): + # forcing to only text modality + lX = {lang: data["text"] for lang, data in lX.items()} _embeds = [] l_embeds = defaultdict(list) @@ -196,8 +198,8 @@ class TextualTransformerGen(ViewGen, TransformerGen): from os import makedirs from os.path import join - vgf_name = "transformerGen" - _basedir = join("models", "vgfs", "transformer") + vgf_name = "textualTransformerGen" + _basedir = join("models", "vgfs", "textual_transformer") makedirs(_basedir, exist_ok=True) _path = join(_basedir, f"{vgf_name}_{model_id}.pkl") with open(_path, "wb") as f: diff --git a/gfun/vgfs/transformerGen.py b/gfun/vgfs/transformerGen.py index 61396f3..f7f4e41 100644 --- a/gfun/vgfs/transformerGen.py +++ b/gfun/vgfs/transformerGen.py @@ -1,6 +1,6 @@ from sklearn.model_selection import train_test_split from torch.utils.data import Dataset, DataLoader -from vgfs.learners.svms import FeatureSet2Posteriors +from gfun.vgfs.learners.svms import FeatureSet2Posteriors class TransformerGen: @@ -67,16 +67,26 @@ class TransformerGen: split="train", shuffle=False, ): - l_tokenized = {lang: processor_fn(data) for lang, data in lX.items()} - self.datasets[split] = torchDataset(l_tokenized, lY, split=split) - return DataLoader(self.datasets[split], batch_size=batch_size, shuffle=shuffle) + l_processed = {lang: processor_fn(lX[lang]) for lang in lX.keys()} + self.datasets[split] = torchDataset(l_processed, lY, split=split) + return DataLoader( + self.datasets[split], + batch_size=batch_size, + shuffle=shuffle, + # collate_fn=processor_fn, + ) - def get_train_val_data(self, lX, lY, split=0.2, seed=42): + def get_train_val_data(self, lX, lY, split=0.2, seed=42, modality="text"): + assert modality in ["text", "image"], "modality must be either text or image" tr_lX, tr_lY, val_lX, val_lY = {}, {}, {}, {} for lang in lX.keys(): tr_X, val_X, tr_Y, val_Y = train_test_split( - lX[lang], lY[lang], test_size=split, random_state=seed, shuffle=False + lX[lang][modality], + lY[lang], + test_size=split, + random_state=seed, + shuffle=False, ) tr_lX[lang] = tr_X tr_lY[lang] = tr_Y diff --git a/gfun/vgfs/vanillaFun.py b/gfun/vgfs/vanillaFun.py index 3416766..551e5c9 100644 --- a/gfun/vgfs/vanillaFun.py +++ b/gfun/vgfs/vanillaFun.py @@ -1,6 +1,6 @@ -from vgfs.viewGen import ViewGen -from vgfs.learners.svms import NaivePolylingualClassifier -from vgfs.commons import _normalize +from gfun.vgfs.viewGen import ViewGen +from gfun.vgfs.learners.svms import NaivePolylingualClassifier +from gfun.vgfs.commons import _normalize class VanillaFunGen(ViewGen): diff --git a/gfun/vgfs/visualTransformerGen.py b/gfun/vgfs/visualTransformerGen.py index 1f9915d..7692c97 100644 --- a/gfun/vgfs/visualTransformerGen.py +++ b/gfun/vgfs/visualTransformerGen.py @@ -1,16 +1,15 @@ -import sys, os - -sys.path.append(os.getcwd()) +from collections import defaultdict +import numpy as np import torch import transformers -from gfun.vgfs.viewGen import ViewGen -from transformers import AutoImageProcessor -from torch.utils.data import DataLoader, Dataset -from torchvision.transforms import 
RandomResizedCrop, Compose, Normalize, ToTensor -from gfun.vgfs.commons import Trainer, predict +from PIL import Image +from torch.utils.data import Dataset +from transformers import AutoImageProcessor, AutoModelForImageClassification + +from gfun.vgfs.commons import Trainer from gfun.vgfs.transformerGen import TransformerGen -from transformers import AutoModelForImageClassification +from gfun.vgfs.viewGen import ViewGen transformers.logging.set_verbosity_error() @@ -26,6 +25,7 @@ class VisualTransformerGen(ViewGen, TransformerGen): batch_size_eval=128, evaluate_step=10, device="cpu", + probabilistic=False, patience=5, ): super().__init__( @@ -38,6 +38,7 @@ class VisualTransformerGen(ViewGen, TransformerGen): device=device, evaluate_step=evaluate_step, patience=patience, + probabilistic=probabilistic, ) self.fitted = False print( @@ -52,44 +53,28 @@ class VisualTransformerGen(ViewGen, TransformerGen): def init_model(self, model_name, num_labels): model = AutoModelForImageClassification.from_pretrained( - model_name, num_labels=num_labels + model_name, num_labels=num_labels, output_hidden_states=True ) image_processor = AutoImageProcessor.from_pretrained(model_name) - transforms = self.init_preprocessor(image_processor) - return model, image_processor, transforms - - def init_preprocessor(self, image_processor): - normalize = Normalize( - mean=image_processor.image_mean, std=image_processor.image_std - ) - size = ( - image_processor.size["shortest_edge"] - if "shortest_edge" in image_processor.size - else (image_processor.size["height"], image_processor.size["width"]) - ) - # these are the transformations that we are applying to the images - transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize]) - return transforms - - def preprocess(self, images, transforms): - processed = transforms(img.convert("RGB") for img in images) - return processed + return model, image_processor def process_all(self, X): - # TODO: every element in X is a tuple (doc_id, clean_text, text, Pil.Image), so we're taking just the last element for processing - processed = torch.stack([self.transforms(img[-1]) for img in X]) - return processed + # TODO: should be moved as a collate_fn to avoid this overhead + processed = self.image_preprocessor( + [Image.open(img).convert("RGB") for img in X], return_tensors="pt" + ) + return processed["pixel_values"] def fit(self, lX, lY): print("- fitting Visual Transformer View Generating Function") _l = list(lX.keys())[0] self.num_labels = lY[_l].shape[-1] - self.model, self.image_preprocessor, self.transforms = self.init_model( + self.model, self.image_preprocessor = self.init_model( self._validate_model_name(self.model_name), num_labels=self.num_labels ) tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data( - lX, lY, split=0.2, seed=42 + lX, lY, split=0.2, seed=42, modality="image" ) tra_dataloader = self.build_dataloader( @@ -135,17 +120,64 @@ class VisualTransformerGen(ViewGen, TransformerGen): if self.probabilistic: self.feature2posterior_projector.fit(self.transform(lX), lY) + self.fitted = True + + return self + def transform(self, lX): - raise NotImplementedError + # forcing to only image modality + lX = {lang: data["image"] for lang, data in lX.items()} + _embeds = [] + l_embeds = defaultdict(list) + + dataloader = self.build_dataloader( + lX, + lY=None, + processor_fn=self.process_all, + torchDataset=MultimodalDatasetTorch, + batch_size=self.batch_size_eval, + split="whole", + shuffle=False, + ) + + self.model.eval() + with torch.no_grad(): + for input_ids, 
lang in dataloader: + input_ids = input_ids.to(self.device) + out = self.model(input_ids).hidden_states[-1] + batch_embeddings = out[:, 0, :].cpu().numpy() + _embeds.append((batch_embeddings, lang)) + + for embed, lang in _embeds: + for sample_embed, sample_lang in zip(embed, lang): + l_embeds[sample_lang].append(sample_embed) + + if self.probabilistic and self.fitted: + l_embeds = self.feature2posterior_projector.transform(l_embeds) + elif not self.probabilistic and self.fitted: + l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()} + + return l_embeds def fit_transform(self, lX, lY): - raise NotImplementedError + return self.fit(lX, lY).transform(lX) def save_vgf(self, model_id): - raise NotImplementedError + import pickle + from os import makedirs + from os.path import join - def save_vgf(self, model_id): - raise NotImplementedError + vgf_name = "visualTransformerGen" + _basedir = join("models", "vgfs", "visual_transformer") + makedirs(_basedir, exist_ok=True) + _path = join(_basedir, f"{vgf_name}_{model_id}.pkl") + with open(_path, "wb") as f: + pickle.dump(self, f) + return self + + def __str__(self): + str = f"[Visual Transformer VGF (v)]\n- model_name: {self.model_name}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n" + return str class MultimodalDatasetTorch(Dataset): @@ -169,7 +201,6 @@ class MultimodalDatasetTorch(Dataset): ], [], ) - # print(f"- lX has shape: {self.X.shape}\n- lY has shape: {self.Y.shape}") def __len__(self): return len(self.X) @@ -182,15 +213,28 @@ class MultimodalDatasetTorch(Dataset): if __name__ == "__main__": from os.path import expanduser - from dataManager.multiNewsDataset import MultiNewsDataset - _dataset_path_hardcoded = "~/datasets/MultiNews/20110730/" + from dataManager.gFunDataset import gFunDataset - dataset = MultiNewsDataset(expanduser(_dataset_path_hardcoded), debug=True) - lXtr, lYtr = dataset.training() + GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset") + dataset = gFunDataset( + dataset_dir=GLAMI_DATAPATH, + is_textual=True, + is_visual=True, + is_multilabel=False, + nrows=50, + ) vg = VisualTransformerGen( - model_name="vit", device="cuda", epochs=1000, evaluate_step=10, patience=100 + dataset_name=dataset.dataset_name, + model_name="vit", + device="cuda", + epochs=5, + evaluate_step=10, + patience=10, + probabilistic=True, ) lX, lY = dataset.training() vg.fit(lX, lY) + out = vg.transform(lX) + exit(0) diff --git a/gfun/vgfs/wceGen.py b/gfun/vgfs/wceGen.py index 6c5e471..a7889df 100644 --- a/gfun/vgfs/wceGen.py +++ b/gfun/vgfs/wceGen.py @@ -1,7 +1,7 @@ import numpy as np from joblib import Parallel, delayed -from vgfs.commons import XdotM, _normalize -from vgfs.viewGen import ViewGen +from gfun.vgfs.commons import XdotM, _normalize +from gfun.vgfs.viewGen import ViewGen class WceGen(ViewGen): diff --git a/main.py b/main.py index fa85836..afa579c 100644 --- a/main.py +++ b/main.py @@ -7,6 +7,7 @@ from dataManager.amazonDataset import AmazonDataset from dataManager.multilingualDatset import MultilingualDataset from dataManager.multiNewsDataset import MultiNewsDataset from dataManager.glamiDataset import GlamiDataset +from dataManager.gFunDataset import gFunDataset from evaluation.evaluate import evaluate, log_eval from gfun.generalizedFunnelling import GeneralizedFunnelling @@ 
-44,11 +45,15 @@ def get_dataset(datasetname): GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset") if datasetname == "multinews": + # TODO: convert to gFunDataset + raise NotImplementedError dataset = MultiNewsDataset( expanduser(MULTINEWS_DATAPATH), excluded_langs=["ar", "pe", "pl", "tr", "ua"], ) elif datasetname == "amazon": + # TODO: convert to gFunDataset + raise NotImplementedError dataset = AmazonDataset( domains=args.domains, nrows=args.nrows, @@ -56,12 +61,21 @@ def get_dataset(datasetname): max_labels=args.max_labels, ) elif datasetname == "rcv1-2": - dataset = MultilingualDataset(dataset_name="rcv1-2").load(RCV_DATAPATH) - if args.nrows is not None: - dataset.reduce_data(langs=["en", "it", "fr"], maxn=args.nrows) + dataset = gFunDataset( + dataset_dir=RCV_DATAPATH, + is_textual=True, + is_visual=False, + is_multilabel=True, + nrows=args.nrows, + ) elif datasetname == "glami": - dataset = GlamiDataset(dataset_dir=GLAMI_DATAPATH, nrows=args.nrows) - dataset.build_dataset() + dataset = gFunDataset( + dataset_dir=GLAMI_DATAPATH, + is_textual=True, + is_visual=True, + is_multilabel=False, + nrows=args.nrows, + ) else: raise NotImplementedError return dataset @@ -73,6 +87,7 @@ def main(args): isinstance(dataset, MultilingualDataset) or isinstance(dataset, MultiNewsDataset) or isinstance(dataset, GlamiDataset) + or isinstance(dataset, gFunDataset) ): lX, lY = dataset.training() lX_te, lY_te = dataset.test() @@ -89,7 +104,8 @@ def main(args): args.wce, args.multilingual, args.multilingual, - args.transformer, + args.textual_transformer, + args.visual_transformer, ] ), "At least one of VGF must be True" @@ -106,8 +122,8 @@ def main(args): # WCE VGF params ---------------------- wce=args.wce, # Transformer VGF params -------------- - transformer=args.transformer, - transformer_name=args.transformer_name, + textual_transformer=args.textual_transformer, + textual_transformer_name=args.transformer_name, batch_size=args.batch_size, epochs=args.epochs, lr=args.lr, @@ -115,6 +131,15 @@ def main(args): patience=args.patience, evaluate_step=args.evaluate_step, device="cuda", + # Visual Transformer VGF params -------------- + visual_transformer=args.visual_transformer, + visual_transformer_name=args.visual_transformer_name, + # batch_size=args.batch_size, + # epochs=args.epochs, + # lr=args.lr, + # patience=args.patience, + # evaluate_step=args.evaluate_step, + # device="cuda", # General params ---------------------- probabilistic=args.features, aggfunc=args.aggfunc, @@ -152,7 +177,7 @@ if __name__ == "__main__": parser.add_argument("--meta", action="store_true") parser.add_argument("--nosave", action="store_true") # Dataset parameters ------------------- - parser.add_argument("-d", "--dataset", type=str, default="multinews") + parser.add_argument("-d", "--dataset", type=str, default="rcv1-2") parser.add_argument("--domains", type=str, default="all") parser.add_argument("--nrows", type=int, default=None) parser.add_argument("--min_count", type=int, default=10) @@ -161,7 +186,8 @@ if __name__ == "__main__": parser.add_argument("-p", "--posteriors", action="store_true") parser.add_argument("-m", "--multilingual", action="store_true") parser.add_argument("-w", "--wce", action="store_true") - parser.add_argument("-t", "--transformer", action="store_true") + parser.add_argument("-t", "--textual_transformer", action="store_true") + parser.add_argument("-v", "--visual_transformer", action="store_true") parser.add_argument("--n_jobs", type=int, default=-1) parser.add_argument("--optimc", 
action="store_true") parser.add_argument("--features", action="store_false") @@ -169,11 +195,13 @@ if __name__ == "__main__": # transformer parameters --------------- parser.add_argument("--transformer_name", type=str, default="mbert") parser.add_argument("--batch_size", type=int, default=32) - parser.add_argument("--epochs", type=int, default=1000) + parser.add_argument("--epochs", type=int, default=100) parser.add_argument("--lr", type=float, default=1e-5) - parser.add_argument("--max_length", type=int, default=512) + parser.add_argument("--max_length", type=int, default=128) parser.add_argument("--patience", type=int, default=5) parser.add_argument("--evaluate_step", type=int, default=10) + # Visual Transformer parameters -------------- + parser.add_argument("--visual_transformer_name", type=str, default="vit") args = parser.parse_args()
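The new VisualTransformerGen.transform() relies on output_hidden_states=True and takes the first ([CLS]) token of the last hidden state as the per-image embedding. A standalone sketch of that extraction, assuming a generic ViT checkpoint (the model name, label count, and image path below are placeholders, not values taken from this patch):

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

model_name = "google/vit-base-patch16-224-in21k"  # placeholder checkpoint
model = AutoModelForImageClassification.from_pretrained(
    model_name, num_labels=10, output_hidden_states=True
)
processor = AutoImageProcessor.from_pretrained(model_name)

# Preprocess a single image, mirroring process_all() above.
pixel_values = processor(
    [Image.open("example.jpg").convert("RGB")], return_tensors="pt"
)["pixel_values"]

model.eval()
with torch.no_grad():
    out = model(pixel_values)
    cls_embedding = out.hidden_states[-1][:, 0, :]  # shape: (1, hidden_size)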
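Finally, an illustrative invocation of main.py with the renamed and newly added VGF flags (the row count is arbitrary; dataset paths and the hard-coded "cuda" device come from the code above):

python main.py -d glami --nrows 10000 -p -t -v --epochs 100 --evaluate_step 10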