From dba2ed9c9c816c57d2d6ae2d5febdc2733e6634e Mon Sep 17 00:00:00 2001 From: andreapdr Date: Thu, 9 Feb 2023 16:55:06 +0100 Subject: [PATCH] Visual Transformer VGF --- dataManager/multiNewsDataset.py | 13 ++- gfun/vgfs/commons.py | 180 ++++++++++++++++++++++++++--- gfun/vgfs/textualTransformerGen.py | 10 +- gfun/vgfs/transformerGen.py | 41 +++++++ gfun/vgfs/visualGen.py | 18 --- gfun/vgfs/visualTransformerGen.py | 175 ++++++++++++++++++++++++++++ 6 files changed, 395 insertions(+), 42 deletions(-) create mode 100644 gfun/vgfs/transformerGen.py delete mode 100644 gfun/vgfs/visualGen.py create mode 100644 gfun/vgfs/visualTransformerGen.py diff --git a/dataManager/multiNewsDataset.py b/dataManager/multiNewsDataset.py index 2958aa6..749403a 100644 --- a/dataManager/multiNewsDataset.py +++ b/dataManager/multiNewsDataset.py @@ -49,7 +49,7 @@ class MultiNewsDataset: from os import listdir if self.debug: - return ["it"] + return ["it", "en"] return tuple(sorted([folder for folder in listdir(self.data_dir)])) @@ -67,7 +67,7 @@ class MultiNewsDataset: def _count_lang_labels(self, labels): lang_labels = set() for l in labels: - lang_labels.update(l[-1]) + lang_labels.update(l) return len(lang_labels) def export_to_torch_dataset(self, tokenizer_id): @@ -125,11 +125,14 @@ class MultiModalDataset: with open(join(self.data_dir, news_folder, fname_doc)) as f: html_doc = f.read() index_path = join(self.data_dir, news_folder, "index.html") - if ".jpg" not in listdir(join(self.data_dir, news_folder)): + if not any( + File.endswith(".jpg") + for File in listdir(join(self.data_dir, news_folder)) + ): img_link, img = self.get_images(index_path) self.save_img(join(self.data_dir, news_folder, "img.jpg"), img) - else: - img = Image.open(join(self.data_dir, news_folder, "img.jpg")) + # TODO: convert img to PIL image + img = Image.open(join(self.data_dir, news_folder, "img.jpg")) clean_doc, doc_labels = self.preprocess_html(html_doc) data.append((fname_doc, clean_doc, html_doc, img)) labels.append(doc_labels) diff --git a/gfun/vgfs/commons.py b/gfun/vgfs/commons.py index 74a57ad..96a485d 100644 --- a/gfun/vgfs/commons.py +++ b/gfun/vgfs/commons.py @@ -1,7 +1,14 @@ -from sklearn.preprocessing import normalize -from sklearn.feature_extraction.text import TfidfVectorizer -from sklearn.decomposition import TruncatedSVD +import os +from collections import defaultdict + import numpy as np +import torch +from sklearn.decomposition import TruncatedSVD +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.preprocessing import normalize +from torch.optim import AdamW + +from evaluation.evaluate import evaluate, log_eval def _normalize(lX, l2=True): @@ -30,6 +37,34 @@ def remove_pc(X, npc=1): return XX +def compute_pc(X, npc=1): + """ + Compute the principal components. + :param X: X[i,:] is a data point + :param npc: number of principal components to remove + :return: component_[i,:] is the i-th pc + """ + if isinstance(X, np.matrix): + X = np.asarray(X) + svd = TruncatedSVD(n_components=npc, n_iter=7, random_state=0) + svd.fit(X) + return svd.components_ + + +def predict(logits, classification_type="multilabel"): + """ + Converts soft precictions to hard predictions [0,1] + """ + if classification_type == "multilabel": + prediction = torch.sigmoid(logits) > 0.5 + elif classification_type == "singlelabel": + prediction = torch.argmax(logits, dim=1).view(-1, 1) + else: + print("unknown classification type") + + return prediction.detach().cpu().numpy() + + class TfidfVectorizerMultilingual: def __init__(self, **kwargs): self.kwargs = kwargs @@ -60,15 +95,130 @@ class TfidfVectorizerMultilingual: return self.vectorizer[l].build_analyzer() -def compute_pc(X, npc=1): - """ - Compute the principal components. - :param X: X[i,:] is a data point - :param npc: number of principal components to remove - :return: component_[i,:] is the i-th pc - """ - if isinstance(X, np.matrix): - X = np.asarray(X) - svd = TruncatedSVD(n_components=npc, n_iter=7, random_state=0) - svd.fit(X) - return svd.components_ +class Trainer: + def __init__( + self, + model, + optimizer_name, + device, + loss_fn, + lr, + print_steps, + evaluate_step, + patience, + experiment_name, + ): + self.device = device + self.model = model.to(device) + self.optimizer = self.init_optimizer(optimizer_name, lr) + self.evaluate_steps = evaluate_step + self.loss_fn = loss_fn.to(device) + self.print_steps = print_steps + self.earlystopping = EarlyStopping( + patience=patience, + checkpoint_path="models/vgfs/transformers/", + verbose=True, + experiment_name=experiment_name, + ) + + def init_optimizer(self, optimizer_name, lr): + if optimizer_name.lower() == "adamw": + return AdamW(self.model.parameters(), lr=lr) + else: + raise ValueError(f"Optimizer {optimizer_name} not supported") + + def train(self, train_dataloader, eval_dataloader, epochs=10): + print( + f"""- Training params: + - epochs: {epochs} + - learning rate: {self.optimizer.defaults['lr']} + - train batch size: {train_dataloader.batch_size} + - eval batch size: {eval_dataloader.batch_size} + - max len: {train_dataloader.dataset.X.shape[-1]}\n""", + ) + for epoch in range(epochs): + self.train_epoch(train_dataloader, epoch) + if (epoch + 1) % self.evaluate_steps == 0: + metric_watcher = self.evaluate(eval_dataloader) + stop = self.earlystopping(metric_watcher, self.model, epoch + 1) + if stop: + break + return self.model + + def train_epoch(self, dataloader, epoch): + self.model.train() + for b_idx, (x, y, lang) in enumerate(dataloader): + self.optimizer.zero_grad() + y_hat = self.model(x.to(self.device)) + loss = self.loss_fn(y_hat.logits, y.to(self.device)) + loss.backward() + self.optimizer.step() + if b_idx % self.print_steps == 0: + print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}") + return self + + def evaluate(self, dataloader): + self.model.eval() + + lY = defaultdict(list) + lY_hat = defaultdict(list) + + for b_idx, (x, y, lang) in enumerate(dataloader): + y_hat = self.model(x.to(self.device)) + loss = self.loss_fn(y_hat.logits, y.to(self.device)) + predictions = predict(y_hat.logits, classification_type="multilabel") + + for l, _true, _pred in zip(lang, y, predictions): + lY[l].append(_true.detach().cpu().numpy()) + lY_hat[l].append(_pred) + + for lang in lY: + lY[lang] = np.vstack(lY[lang]) + lY_hat[lang] = np.vstack(lY_hat[lang]) + + l_eval = evaluate(lY, lY_hat) + average_metrics = log_eval(l_eval, phase="validation") + return average_metrics[0] # macro-F1 + + +class EarlyStopping: + def __init__( + self, + patience=5, + min_delta=0, + verbose=True, + checkpoint_path="checkpoint.pt", + experiment_name="experiment", + ): + self.patience = patience + self.min_delta = min_delta + self.counter = 0 + self.best_score = 0 + self.best_epoch = None + self.verbose = verbose + self.checkpoint_path = checkpoint_path + self.experiment_name = experiment_name + + def __call__(self, validation, model, epoch): + if validation > self.best_score: + print( + f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}" + ) + self.best_score = validation + self.counter = 0 + # self.save_model(model) + elif validation < (self.best_score + self.min_delta): + self.counter += 1 + print( + f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}" + ) + if self.counter >= self.patience: + if self.verbose: + print(f"- earlystopping: Early stopping at epoch {epoch}") + return True + + def save_model(self, model): + _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name) + print(f"- saving model to {_checkpoint_dir}") + os.makedirs(_checkpoint_dir, exist_ok=True) + model.save_pretrained(_checkpoint_dir) diff --git a/gfun/vgfs/textualTransformerGen.py b/gfun/vgfs/textualTransformerGen.py index eb134bd..c705160 100644 --- a/gfun/vgfs/textualTransformerGen.py +++ b/gfun/vgfs/textualTransformerGen.py @@ -12,6 +12,7 @@ from torch.optim import AdamW from torch.utils.data import DataLoader, Dataset from transformers import AutoModelForSequenceClassification, AutoTokenizer from vgfs.learners.svms import FeatureSet2Posteriors +from vgfs.viewGen import ViewGen from evaluation.evaluate import evaluate, log_eval @@ -19,9 +20,10 @@ transformers.logging.set_verbosity_error() # TODO: add support to loggers +# TODO: multiple inheritance - maybe define a superclass for TransformerGenerator, whether it is a Textual or a Visual one, implementing dataset creation functions -class TextualTransformerGen: +class TextualTransformerGen(ViewGen): def __init__( self, model_name, @@ -105,7 +107,7 @@ class TextualTransformerGen: return tr_lX, tr_lY, val_lX, val_lY - def build_dataloader(self, lX, lY, batch_size, split="train", shuffle=True): + def build_dataloader(self, lX, lY, batch_size, split="train", shuffle=False): l_tokenized = {lang: self._tokenize(data) for lang, data in lX.items()} self.datasets[split] = MultilingualDatasetTorch(l_tokenized, lY, split=split) return DataLoader(self.datasets[split], batch_size=batch_size, shuffle=shuffle) @@ -122,7 +124,7 @@ class TextualTransformerGen: def fit(self, lX, lY): if self.fitted: return self - print("- fitting Transformer View Generating Function") + print("- fitting Textual Transformer View Generating Function") _l = list(lX.keys())[0] self.num_labels = lY[_l].shape[-1] self.model, self.tokenizer = self.init_model( @@ -196,8 +198,8 @@ class TextualTransformerGen: def save_vgf(self, model_id): import pickle - from os.path import join from os import makedirs + from os.path import join vgf_name = "transformerGen" _basedir = join("models", "vgfs", "transformer") diff --git a/gfun/vgfs/transformerGen.py b/gfun/vgfs/transformerGen.py new file mode 100644 index 0000000..3dc3814 --- /dev/null +++ b/gfun/vgfs/transformerGen.py @@ -0,0 +1,41 @@ +from sklearn.model_selection import train_test_split +from torch.utils.data import Dataset, DataLoader + + +class TransformerGen: + """Base class for all transformers. It implements the basic methods for + the creation of the datasets, datalaoders and the train-val split method. + It is designed to be used with MultilingualDataset in the + form of dictioanries {lang: data} + """ + + def __init__(self): + self.datasets = {} + + def build_dataloader( + self, + lX, + lY, + torchDataset, + processor_fn, + batch_size, + split="train", + shuffle=False, + ): + l_tokenized = {lang: processor_fn(data) for lang, data in lX.items()} + self.datasets[split] = torchDataset(l_tokenized, lY, split=split) + return DataLoader(self.datasets[split], batch_size=batch_size, shuffle=shuffle) + + def get_train_val_data(self, lX, lY, split=0.2, seed=42): + tr_lX, tr_lY, val_lX, val_lY = {}, {}, {}, {} + + for lang in lX.keys(): + tr_X, val_X, tr_Y, val_Y = train_test_split( + lX[lang], lY[lang], test_size=split, random_state=seed, shuffle=False + ) + tr_lX[lang] = tr_X + tr_lY[lang] = tr_Y + val_lX[lang] = val_X + val_lY[lang] = val_Y + + return tr_lX, tr_lY, val_lX, val_lY diff --git a/gfun/vgfs/visualGen.py b/gfun/vgfs/visualGen.py deleted file mode 100644 index 94ff96c..0000000 --- a/gfun/vgfs/visualGen.py +++ /dev/null @@ -1,18 +0,0 @@ -from vgfs.viewGen import ViewGen - - -class VisualGen(ViewGen): - def fit(): - raise NotImplemented - - def transform(self, lX): - return super().transform(lX) - - def fit_transform(self, lX, lY): - return super().fit_transform(lX, lY) - - def save_vgf(self, model_id): - return super().save_vgf(model_id) - - def save_vgf(self, model_id): - return super().save_vgf(model_id) diff --git a/gfun/vgfs/visualTransformerGen.py b/gfun/vgfs/visualTransformerGen.py new file mode 100644 index 0000000..aaa7651 --- /dev/null +++ b/gfun/vgfs/visualTransformerGen.py @@ -0,0 +1,175 @@ +import sys, os + +sys.path.append(os.getcwd()) + +import torch +import transformers +from gfun.vgfs.viewGen import ViewGen +from transformers import AutoImageProcessor +from torch.utils.data import DataLoader, Dataset +from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor +from gfun.vgfs.commons import Trainer, predict +from gfun.vgfs.transformerGen import TransformerGen +from transformers import AutoModelForImageClassification, TrainingArguments, Trainer + +transformers.logging.set_verbosity_error() + + +class VisualTransformerGen(ViewGen, TransformerGen): + def __init__( + self, model_name, lr=1e-5, epochs=10, batch_size=32, batch_size_eval=128 + ): + self.model_name = model_name + self.datasets = {} + self.lr = lr + self.epochs = epochs + self.batch_size = batch_size + self.batch_size_eval = batch_size_eval + + def _validate_model_name(self, model_name): + if "vit" == model_name: + return "google/vit-base-patch16-224-in21k" + else: + raise NotImplementedError + + def init_model(self, model_name, num_labels): + model = ( + AutoModelForImageClassification.from_pretrained( + model_name, num_labels=num_labels + ), + ) + image_processor = AutoImageProcessor.from_pretrained(model_name) + transforms = self.init_preprocessor(image_processor) + return model, image_processor, transforms + + def init_preprocessor(self, image_processor): + normalize = Normalize( + mean=image_processor.image_mean, std=image_processor.image_std + ) + size = ( + image_processor.size["shortest_edge"] + if "shortest_edge" in image_processor.size + else (image_processor.size["height"], image_processor.size["width"]) + ) + # these are the transformations that we are applying to the images + transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize]) + return transforms + + def preprocess(self, images, transforms): + processed = transforms(img.convert("RGB") for img in images) + return processed + + def process_all(self, X): + # TODO: every element in X is a tuple (doc_id, clean_text, text, Pil.Image), so we're taking just the last element for processing + processed = torch.stack([self.transforms(img[-1]) for img in X]) + return processed + + def fit(self, lX, lY): + print("- fitting Visual Transformer View Generating Function") + _l = list(lX.keys())[0] + self.num_labels = lY[_l].shape[-1] + self.model, self.image_preprocessor, self.transforms = self.init_model( + self._validate_model_name(self.model_name), num_labels=self.num_labels + ) + + tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data( + lX, lY, split=0.2, seed=42 + ) + + tra_dataloader = self.build_dataloader( + tr_lX, + tr_lY, + processor_fn=self.process_all, + torchDataset=MultimodalDatasetTorch, + batch_size=self.batch_size, + split="train", + shuffle=True, + ) + + val_dataloader = self.build_dataloader( + val_lX, + val_lY, + processor_fn=self.process_all, + torchDataset=MultimodalDatasetTorch, + batch_size=self.batch_size_eval, + split="val", + shuffle=False, + ) + + experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}" + trainer = Trainer( + model=self.model, + optimizer_name="adamW", + lr=self.lr, + device=self.device, + loss_fn=torch.nn.CrossEntropyLoss(), + print_steps=self.print_steps, + evaluate_step=self.evaluate_step, + patience=self.patience, + experiment_name=experiment_name, + ) + + trainer.train( + train_dataloader=tra_dataloader, + val_dataloader=val_dataloader, + epochs=self.epochs, + ) + + def transform(self, lX): + raise NotImplementedError + + def fit_transform(self, lX, lY): + raise NotImplementedError + + def save_vgf(self, model_id): + raise NotImplementedError + + def save_vgf(self, model_id): + raise NotImplementedError + + +class MultimodalDatasetTorch(Dataset): + def __init__(self, lX, lY, split="train"): + self.lX = lX + self.lY = lY + self.split = split + self.langs = [] + self.init() + + def init(self): + self.X = torch.vstack([imgs for imgs in self.lX.values()]) + if self.split != "whole": + self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()]) + self.langs = sum( + [ + v + for v in { + lang: [lang] * len(data) for lang, data in self.lX.items() + }.values() + ], + [], + ) + print(f"- lX has shape: {self.X.shape}\n- lY has shape: {self.Y.shape}") + + def __len__(self): + return len(self.X) + + def __getitem__(self, index): + if self.split == "whole": + return self.X[index], self.langs[index] + return self.X[index], self.Y[index], self.langs[index] + + +if __name__ == "__main__": + from os.path import expanduser + from dataManager.multiNewsDataset import MultiNewsDataset + + _dataset_path_hardcoded = "~/datasets/MultiNews/20110730/" + + dataset = MultiNewsDataset(expanduser(_dataset_path_hardcoded), debug=True) + lXtr, lYtr = dataset.training() + + vg = VisualTransformerGen(model_name="vit") + lX, lY = dataset.training() + vg.fit(lX, lY) + print("lel")