Visual Transformer VGF

This commit is contained in:
Andrea Pedrotti 2023-02-09 16:55:06 +01:00
parent 4485d97e03
commit dba2ed9c9c
6 changed files with 395 additions and 42 deletions

View File

@ -49,7 +49,7 @@ class MultiNewsDataset:
from os import listdir from os import listdir
if self.debug: if self.debug:
return ["it"] return ["it", "en"]
return tuple(sorted([folder for folder in listdir(self.data_dir)])) return tuple(sorted([folder for folder in listdir(self.data_dir)]))
@ -67,7 +67,7 @@ class MultiNewsDataset:
def _count_lang_labels(self, labels): def _count_lang_labels(self, labels):
lang_labels = set() lang_labels = set()
for l in labels: for l in labels:
lang_labels.update(l[-1]) lang_labels.update(l)
return len(lang_labels) return len(lang_labels)
def export_to_torch_dataset(self, tokenizer_id): def export_to_torch_dataset(self, tokenizer_id):
@ -125,11 +125,14 @@ class MultiModalDataset:
with open(join(self.data_dir, news_folder, fname_doc)) as f: with open(join(self.data_dir, news_folder, fname_doc)) as f:
html_doc = f.read() html_doc = f.read()
index_path = join(self.data_dir, news_folder, "index.html") index_path = join(self.data_dir, news_folder, "index.html")
if ".jpg" not in listdir(join(self.data_dir, news_folder)): if not any(
File.endswith(".jpg")
for File in listdir(join(self.data_dir, news_folder))
):
img_link, img = self.get_images(index_path) img_link, img = self.get_images(index_path)
self.save_img(join(self.data_dir, news_folder, "img.jpg"), img) self.save_img(join(self.data_dir, news_folder, "img.jpg"), img)
else: # TODO: convert img to PIL image
img = Image.open(join(self.data_dir, news_folder, "img.jpg")) img = Image.open(join(self.data_dir, news_folder, "img.jpg"))
clean_doc, doc_labels = self.preprocess_html(html_doc) clean_doc, doc_labels = self.preprocess_html(html_doc)
data.append((fname_doc, clean_doc, html_doc, img)) data.append((fname_doc, clean_doc, html_doc, img))
labels.append(doc_labels) labels.append(doc_labels)

View File

@ -1,7 +1,14 @@
from sklearn.preprocessing import normalize import os
from sklearn.feature_extraction.text import TfidfVectorizer from collections import defaultdict
from sklearn.decomposition import TruncatedSVD
import numpy as np import numpy as np
import torch
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import normalize
from torch.optim import AdamW
from evaluation.evaluate import evaluate, log_eval
def _normalize(lX, l2=True): def _normalize(lX, l2=True):
@ -30,6 +37,34 @@ def remove_pc(X, npc=1):
return XX return XX
def compute_pc(X, npc=1):
"""
Compute the principal components.
:param X: X[i,:] is a data point
:param npc: number of principal components to remove
:return: component_[i,:] is the i-th pc
"""
if isinstance(X, np.matrix):
X = np.asarray(X)
svd = TruncatedSVD(n_components=npc, n_iter=7, random_state=0)
svd.fit(X)
return svd.components_
def predict(logits, classification_type="multilabel"):
"""
Converts soft precictions to hard predictions [0,1]
"""
if classification_type == "multilabel":
prediction = torch.sigmoid(logits) > 0.5
elif classification_type == "singlelabel":
prediction = torch.argmax(logits, dim=1).view(-1, 1)
else:
print("unknown classification type")
return prediction.detach().cpu().numpy()
class TfidfVectorizerMultilingual: class TfidfVectorizerMultilingual:
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.kwargs = kwargs self.kwargs = kwargs
@ -60,15 +95,130 @@ class TfidfVectorizerMultilingual:
return self.vectorizer[l].build_analyzer() return self.vectorizer[l].build_analyzer()
def compute_pc(X, npc=1): class Trainer:
""" def __init__(
Compute the principal components. self,
:param X: X[i,:] is a data point model,
:param npc: number of principal components to remove optimizer_name,
:return: component_[i,:] is the i-th pc device,
""" loss_fn,
if isinstance(X, np.matrix): lr,
X = np.asarray(X) print_steps,
svd = TruncatedSVD(n_components=npc, n_iter=7, random_state=0) evaluate_step,
svd.fit(X) patience,
return svd.components_ experiment_name,
):
self.device = device
self.model = model.to(device)
self.optimizer = self.init_optimizer(optimizer_name, lr)
self.evaluate_steps = evaluate_step
self.loss_fn = loss_fn.to(device)
self.print_steps = print_steps
self.earlystopping = EarlyStopping(
patience=patience,
checkpoint_path="models/vgfs/transformers/",
verbose=True,
experiment_name=experiment_name,
)
def init_optimizer(self, optimizer_name, lr):
if optimizer_name.lower() == "adamw":
return AdamW(self.model.parameters(), lr=lr)
else:
raise ValueError(f"Optimizer {optimizer_name} not supported")
def train(self, train_dataloader, eval_dataloader, epochs=10):
print(
f"""- Training params:
- epochs: {epochs}
- learning rate: {self.optimizer.defaults['lr']}
- train batch size: {train_dataloader.batch_size}
- eval batch size: {eval_dataloader.batch_size}
- max len: {train_dataloader.dataset.X.shape[-1]}\n""",
)
for epoch in range(epochs):
self.train_epoch(train_dataloader, epoch)
if (epoch + 1) % self.evaluate_steps == 0:
metric_watcher = self.evaluate(eval_dataloader)
stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
if stop:
break
return self.model
def train_epoch(self, dataloader, epoch):
self.model.train()
for b_idx, (x, y, lang) in enumerate(dataloader):
self.optimizer.zero_grad()
y_hat = self.model(x.to(self.device))
loss = self.loss_fn(y_hat.logits, y.to(self.device))
loss.backward()
self.optimizer.step()
if b_idx % self.print_steps == 0:
print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
return self
def evaluate(self, dataloader):
self.model.eval()
lY = defaultdict(list)
lY_hat = defaultdict(list)
for b_idx, (x, y, lang) in enumerate(dataloader):
y_hat = self.model(x.to(self.device))
loss = self.loss_fn(y_hat.logits, y.to(self.device))
predictions = predict(y_hat.logits, classification_type="multilabel")
for l, _true, _pred in zip(lang, y, predictions):
lY[l].append(_true.detach().cpu().numpy())
lY_hat[l].append(_pred)
for lang in lY:
lY[lang] = np.vstack(lY[lang])
lY_hat[lang] = np.vstack(lY_hat[lang])
l_eval = evaluate(lY, lY_hat)
average_metrics = log_eval(l_eval, phase="validation")
return average_metrics[0] # macro-F1
class EarlyStopping:
def __init__(
self,
patience=5,
min_delta=0,
verbose=True,
checkpoint_path="checkpoint.pt",
experiment_name="experiment",
):
self.patience = patience
self.min_delta = min_delta
self.counter = 0
self.best_score = 0
self.best_epoch = None
self.verbose = verbose
self.checkpoint_path = checkpoint_path
self.experiment_name = experiment_name
def __call__(self, validation, model, epoch):
if validation > self.best_score:
print(
f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
)
self.best_score = validation
self.counter = 0
# self.save_model(model)
elif validation < (self.best_score + self.min_delta):
self.counter += 1
print(
f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}"
)
if self.counter >= self.patience:
if self.verbose:
print(f"- earlystopping: Early stopping at epoch {epoch}")
return True
def save_model(self, model):
_checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
print(f"- saving model to {_checkpoint_dir}")
os.makedirs(_checkpoint_dir, exist_ok=True)
model.save_pretrained(_checkpoint_dir)

View File

@ -12,6 +12,7 @@ from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer from transformers import AutoModelForSequenceClassification, AutoTokenizer
from vgfs.learners.svms import FeatureSet2Posteriors from vgfs.learners.svms import FeatureSet2Posteriors
from vgfs.viewGen import ViewGen
from evaluation.evaluate import evaluate, log_eval from evaluation.evaluate import evaluate, log_eval
@ -19,9 +20,10 @@ transformers.logging.set_verbosity_error()
# TODO: add support to loggers # TODO: add support to loggers
# TODO: multiple inheritance - maybe define a superclass for TransformerGenerator, whether it is a Textual or a Visual one, implementing dataset creation functions
class TextualTransformerGen: class TextualTransformerGen(ViewGen):
def __init__( def __init__(
self, self,
model_name, model_name,
@ -105,7 +107,7 @@ class TextualTransformerGen:
return tr_lX, tr_lY, val_lX, val_lY return tr_lX, tr_lY, val_lX, val_lY
def build_dataloader(self, lX, lY, batch_size, split="train", shuffle=True): def build_dataloader(self, lX, lY, batch_size, split="train", shuffle=False):
l_tokenized = {lang: self._tokenize(data) for lang, data in lX.items()} l_tokenized = {lang: self._tokenize(data) for lang, data in lX.items()}
self.datasets[split] = MultilingualDatasetTorch(l_tokenized, lY, split=split) self.datasets[split] = MultilingualDatasetTorch(l_tokenized, lY, split=split)
return DataLoader(self.datasets[split], batch_size=batch_size, shuffle=shuffle) return DataLoader(self.datasets[split], batch_size=batch_size, shuffle=shuffle)
@ -122,7 +124,7 @@ class TextualTransformerGen:
def fit(self, lX, lY): def fit(self, lX, lY):
if self.fitted: if self.fitted:
return self return self
print("- fitting Transformer View Generating Function") print("- fitting Textual Transformer View Generating Function")
_l = list(lX.keys())[0] _l = list(lX.keys())[0]
self.num_labels = lY[_l].shape[-1] self.num_labels = lY[_l].shape[-1]
self.model, self.tokenizer = self.init_model( self.model, self.tokenizer = self.init_model(
@ -196,8 +198,8 @@ class TextualTransformerGen:
def save_vgf(self, model_id): def save_vgf(self, model_id):
import pickle import pickle
from os.path import join
from os import makedirs from os import makedirs
from os.path import join
vgf_name = "transformerGen" vgf_name = "transformerGen"
_basedir = join("models", "vgfs", "transformer") _basedir = join("models", "vgfs", "transformer")

View File

@ -0,0 +1,41 @@
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
class TransformerGen:
"""Base class for all transformers. It implements the basic methods for
the creation of the datasets, datalaoders and the train-val split method.
It is designed to be used with MultilingualDataset in the
form of dictioanries {lang: data}
"""
def __init__(self):
self.datasets = {}
def build_dataloader(
self,
lX,
lY,
torchDataset,
processor_fn,
batch_size,
split="train",
shuffle=False,
):
l_tokenized = {lang: processor_fn(data) for lang, data in lX.items()}
self.datasets[split] = torchDataset(l_tokenized, lY, split=split)
return DataLoader(self.datasets[split], batch_size=batch_size, shuffle=shuffle)
def get_train_val_data(self, lX, lY, split=0.2, seed=42):
tr_lX, tr_lY, val_lX, val_lY = {}, {}, {}, {}
for lang in lX.keys():
tr_X, val_X, tr_Y, val_Y = train_test_split(
lX[lang], lY[lang], test_size=split, random_state=seed, shuffle=False
)
tr_lX[lang] = tr_X
tr_lY[lang] = tr_Y
val_lX[lang] = val_X
val_lY[lang] = val_Y
return tr_lX, tr_lY, val_lX, val_lY

View File

@ -1,18 +0,0 @@
from vgfs.viewGen import ViewGen
class VisualGen(ViewGen):
def fit():
raise NotImplemented
def transform(self, lX):
return super().transform(lX)
def fit_transform(self, lX, lY):
return super().fit_transform(lX, lY)
def save_vgf(self, model_id):
return super().save_vgf(model_id)
def save_vgf(self, model_id):
return super().save_vgf(model_id)

View File

@ -0,0 +1,175 @@
import sys, os
sys.path.append(os.getcwd())
import torch
import transformers
from gfun.vgfs.viewGen import ViewGen
from transformers import AutoImageProcessor
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
from gfun.vgfs.commons import Trainer, predict
from gfun.vgfs.transformerGen import TransformerGen
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
transformers.logging.set_verbosity_error()
class VisualTransformerGen(ViewGen, TransformerGen):
def __init__(
self, model_name, lr=1e-5, epochs=10, batch_size=32, batch_size_eval=128
):
self.model_name = model_name
self.datasets = {}
self.lr = lr
self.epochs = epochs
self.batch_size = batch_size
self.batch_size_eval = batch_size_eval
def _validate_model_name(self, model_name):
if "vit" == model_name:
return "google/vit-base-patch16-224-in21k"
else:
raise NotImplementedError
def init_model(self, model_name, num_labels):
model = (
AutoModelForImageClassification.from_pretrained(
model_name, num_labels=num_labels
),
)
image_processor = AutoImageProcessor.from_pretrained(model_name)
transforms = self.init_preprocessor(image_processor)
return model, image_processor, transforms
def init_preprocessor(self, image_processor):
normalize = Normalize(
mean=image_processor.image_mean, std=image_processor.image_std
)
size = (
image_processor.size["shortest_edge"]
if "shortest_edge" in image_processor.size
else (image_processor.size["height"], image_processor.size["width"])
)
# these are the transformations that we are applying to the images
transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])
return transforms
def preprocess(self, images, transforms):
processed = transforms(img.convert("RGB") for img in images)
return processed
def process_all(self, X):
# TODO: every element in X is a tuple (doc_id, clean_text, text, Pil.Image), so we're taking just the last element for processing
processed = torch.stack([self.transforms(img[-1]) for img in X])
return processed
def fit(self, lX, lY):
print("- fitting Visual Transformer View Generating Function")
_l = list(lX.keys())[0]
self.num_labels = lY[_l].shape[-1]
self.model, self.image_preprocessor, self.transforms = self.init_model(
self._validate_model_name(self.model_name), num_labels=self.num_labels
)
tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
lX, lY, split=0.2, seed=42
)
tra_dataloader = self.build_dataloader(
tr_lX,
tr_lY,
processor_fn=self.process_all,
torchDataset=MultimodalDatasetTorch,
batch_size=self.batch_size,
split="train",
shuffle=True,
)
val_dataloader = self.build_dataloader(
val_lX,
val_lY,
processor_fn=self.process_all,
torchDataset=MultimodalDatasetTorch,
batch_size=self.batch_size_eval,
split="val",
shuffle=False,
)
experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"
trainer = Trainer(
model=self.model,
optimizer_name="adamW",
lr=self.lr,
device=self.device,
loss_fn=torch.nn.CrossEntropyLoss(),
print_steps=self.print_steps,
evaluate_step=self.evaluate_step,
patience=self.patience,
experiment_name=experiment_name,
)
trainer.train(
train_dataloader=tra_dataloader,
val_dataloader=val_dataloader,
epochs=self.epochs,
)
def transform(self, lX):
raise NotImplementedError
def fit_transform(self, lX, lY):
raise NotImplementedError
def save_vgf(self, model_id):
raise NotImplementedError
def save_vgf(self, model_id):
raise NotImplementedError
class MultimodalDatasetTorch(Dataset):
def __init__(self, lX, lY, split="train"):
self.lX = lX
self.lY = lY
self.split = split
self.langs = []
self.init()
def init(self):
self.X = torch.vstack([imgs for imgs in self.lX.values()])
if self.split != "whole":
self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
self.langs = sum(
[
v
for v in {
lang: [lang] * len(data) for lang, data in self.lX.items()
}.values()
],
[],
)
print(f"- lX has shape: {self.X.shape}\n- lY has shape: {self.Y.shape}")
def __len__(self):
return len(self.X)
def __getitem__(self, index):
if self.split == "whole":
return self.X[index], self.langs[index]
return self.X[index], self.Y[index], self.langs[index]
if __name__ == "__main__":
from os.path import expanduser
from dataManager.multiNewsDataset import MultiNewsDataset
_dataset_path_hardcoded = "~/datasets/MultiNews/20110730/"
dataset = MultiNewsDataset(expanduser(_dataset_path_hardcoded), debug=True)
lXtr, lYtr = dataset.training()
vg = VisualTransformerGen(model_name="vit")
lX, lY = dataset.training()
vg.fit(lX, lY)
print("lel")