# gfun_multimodal/gfun/vgfs/visualTransformerGen.py
import sys, os
sys.path.append(os.getcwd())
import torch
import transformers
from gfun.vgfs.viewGen import ViewGen
from transformers import AutoImageProcessor
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
from gfun.vgfs.commons import Trainer, predict
from gfun.vgfs.transformerGen import TransformerGen
from transformers import AutoModelForImageClassification
transformers.logging.set_verbosity_error()
class VisualTransformerGen(ViewGen, TransformerGen):
    """View Generating Function backed by a pre-trained Vision Transformer.

    Fine-tunes a HuggingFace image-classification model (currently only ViT)
    on the visual component of a multimodal dataset. Training plumbing
    (train/val split, dataloader construction, `print_steps`, ...) comes from
    `TransformerGen`; the VGF interface comes from `ViewGen`.
    """

    def __init__(
        self,
        model_name,
        lr=1e-5,
        epochs=10,
        batch_size=32,
        batch_size_eval=128,
        evaluate_step=10,
        device="cpu",
        patience=5,
    ):
        """Store hyper-parameters via the TransformerGen initializer.

        Args:
            model_name: short alias of the visual backbone (e.g. "vit").
            lr: learning rate for fine-tuning.
            epochs: maximum number of training epochs.
            batch_size: training batch size.
            batch_size_eval: evaluation batch size.
            evaluate_step: evaluate on the validation set every N steps.
            device: torch device string ("cpu", "cuda", ...).
            patience: early-stopping patience (in evaluation rounds).
        """
        super().__init__(
            model_name,
            lr=lr,
            epochs=epochs,
            batch_size=batch_size,
            batch_size_eval=batch_size_eval,
            device=device,
            evaluate_step=evaluate_step,
            patience=patience,
        )

    def _validate_model_name(self, model_name):
        """Map a short model alias to its HuggingFace Hub identifier.

        Raises:
            NotImplementedError: if the alias is not supported.
        """
        if model_name == "vit":
            return "google/vit-base-patch16-224-in21k"
        raise NotImplementedError(f"unsupported visual model name: {model_name!r}")

    def init_model(self, model_name, num_labels):
        """Load the pre-trained model, its image processor, and the transforms.

        Args:
            model_name: full HuggingFace Hub identifier.
            num_labels: size of the (multi-label) classification head.

        Returns:
            (model, image_processor, transforms) triple.
        """
        model = AutoModelForImageClassification.from_pretrained(
            model_name, num_labels=num_labels
        )
        image_processor = AutoImageProcessor.from_pretrained(model_name)
        transforms = self.init_preprocessor(image_processor)
        return model, image_processor, transforms

    def init_preprocessor(self, image_processor):
        """Build the torchvision pipeline matching the processor's config.

        Uses the processor's mean/std for normalization and its configured
        size (either a shortest-edge int or an explicit height/width pair)
        for the random resized crop.
        """
        normalize = Normalize(
            mean=image_processor.image_mean, std=image_processor.image_std
        )
        size = (
            image_processor.size["shortest_edge"]
            if "shortest_edge" in image_processor.size
            else (image_processor.size["height"], image_processor.size["width"])
        )
        # These are the transformations applied to every training image.
        transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])
        return transforms

    def preprocess(self, images, transforms):
        """Apply `transforms` to each PIL image after converting it to RGB.

        Bug fix: the original passed a single *generator* object to
        `transforms`, but a torchvision `Compose` operates on one image at a
        time — the transform must be applied per image.
        """
        return [transforms(img.convert("RGB")) for img in images]

    def process_all(self, X):
        """Stack the transformed images of a batch into one tensor.

        TODO: every element in X is a tuple (doc_id, clean_text, text,
        PIL.Image), so we take just the last element for processing.
        """
        return torch.stack([self.transforms(img[-1]) for img in X])

    def fit(self, lX, lY):
        """Fine-tune the visual transformer on {lang: data} dictionaries.

        Args:
            lX: dict lang -> list of (doc_id, clean_text, text, PIL.Image).
            lY: dict lang -> label matrix; the last dimension fixes num_labels.
        """
        print("- fitting Visual Transformer View Generating Function")
        # Infer the label-space size from any one language (all share it).
        _l = list(lX.keys())[0]
        self.num_labels = lY[_l].shape[-1]
        self.model, self.image_preprocessor, self.transforms = self.init_model(
            self._validate_model_name(self.model_name), num_labels=self.num_labels
        )
        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
            lX, lY, split=0.2, seed=42
        )
        tra_dataloader = self.build_dataloader(
            tr_lX,
            tr_lY,
            processor_fn=self.process_all,
            torchDataset=MultimodalDatasetTorch,
            batch_size=self.batch_size,
            split="train",
            shuffle=True,
        )
        val_dataloader = self.build_dataloader(
            val_lX,
            val_lY,
            processor_fn=self.process_all,
            torchDataset=MultimodalDatasetTorch,
            batch_size=self.batch_size_eval,
            split="val",
            shuffle=False,
        )
        experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"
        trainer = Trainer(
            model=self.model,
            optimizer_name="adamW",
            device=self.device,
            loss_fn=torch.nn.CrossEntropyLoss(),
            lr=self.lr,
            print_steps=self.print_steps,
            evaluate_step=self.evaluate_step,
            patience=self.patience,
            experiment_name=experiment_name,
        )
        trainer.train(
            train_dataloader=tra_dataloader,
            eval_dataloader=val_dataloader,
            epochs=self.epochs,
        )

    def transform(self, lX):
        raise NotImplementedError

    def fit_transform(self, lX, lY):
        raise NotImplementedError

    # NOTE(review): the original defined save_vgf twice with identical
    # bodies; the duplicate has been removed.
    def save_vgf(self, model_id):
        raise NotImplementedError
class MultimodalDatasetTorch(Dataset):
    """Torch Dataset flattening per-language image tensors into one view.

    `lX` maps language -> stacked image tensor; `lY` maps language -> label
    matrix. For split "whole" no labels are materialized and items are
    (image, lang) pairs; otherwise items are (image, label, lang) triples.
    """

    def __init__(self, lX, lY, split="train"):
        self.lX = lX
        self.lY = lY
        self.split = split
        self.langs = []
        self.init()

    def init(self):
        """Stack all languages' tensors and record each row's language tag."""
        self.X = torch.vstack(list(self.lX.values()))
        if self.split != "whole":
            self.Y = torch.vstack([torch.Tensor(labels) for labels in self.lY.values()])
        # One language tag per stacked row, in insertion order of lX.
        self.langs = [
            lang for lang, items in self.lX.items() for _ in range(len(items))
        ]
        # print(f"- lX has shape: {self.X.shape}\n- lY has shape: {self.Y.shape}")

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if self.split == "whole":
            return self.X[index], self.langs[index]
        return self.X[index], self.Y[index], self.langs[index]
if __name__ == "__main__":
    # Smoke test: fit the visual VGF on a debug slice of MultiNews.
    from os.path import expanduser

    from dataManager.multiNewsDataset import MultiNewsDataset

    _dataset_path_hardcoded = "~/datasets/MultiNews/20110730/"
    dataset = MultiNewsDataset(expanduser(_dataset_path_hardcoded), debug=True)
    # NOTE(review): the original called dataset.training() twice and discarded
    # the first result (lXtr, lYtr); load the training split once.
    lX, lY = dataset.training()
    vg = VisualTransformerGen(
        model_name="vit", device="cuda", epochs=1000, evaluate_step=10, patience=100
    )
    vg.fit(lX, lY)