import sys, os

sys.path.append(os.getcwd())

import torch
import transformers
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Normalize, RandomResizedCrop, ToTensor
from transformers import AutoImageProcessor, AutoModelForImageClassification

from gfun.vgfs.commons import Trainer, predict
from gfun.vgfs.transformerGen import TransformerGen
from gfun.vgfs.viewGen import ViewGen

transformers.logging.set_verbosity_error()


class VisualTransformerGen(ViewGen, TransformerGen):
    """View-generating function that fine-tunes a pre-trained Vision
    Transformer on the visual component of a multilingual dataset.
    """

    def __init__(
        self,
        model_name,
        lr=1e-5,
        epochs=10,
        batch_size=32,
        batch_size_eval=128,
        evaluate_step=10,
        device="cpu",
        patience=5,
    ):
        super().__init__(
            model_name,
            lr=lr,
            epochs=epochs,
            batch_size=batch_size,
            batch_size_eval=batch_size_eval,
            device=device,
            evaluate_step=evaluate_step,
            patience=patience,
        )

    def _validate_model_name(self, model_name):
        # Resolve the short model alias to its full HuggingFace Hub identifier.
        if model_name == "vit":
            return "google/vit-base-patch16-224-in21k"
        else:
            raise NotImplementedError(f"Model '{model_name}' is not supported")

    def init_model(self, model_name, num_labels):
        # Load the pre-trained backbone with a freshly initialized
        # classification head sized for the task, together with its
        # image processor and the derived preprocessing pipeline.
        model = AutoModelForImageClassification.from_pretrained(
            model_name, num_labels=num_labels
        )
        image_processor = AutoImageProcessor.from_pretrained(model_name)
        transforms = self.init_preprocessor(image_processor)
        return model, image_processor, transforms

    def init_preprocessor(self, image_processor):
        normalize = Normalize(
            mean=image_processor.image_mean, std=image_processor.image_std
        )
        # The target size is either the shortest edge or an explicit
        # (height, width) pair, depending on the image processor's config.
        size = (
            image_processor.size["shortest_edge"]
            if "shortest_edge" in image_processor.size
            else (image_processor.size["height"], image_processor.size["width"])
        )
        # Transformations applied to every image: random crop-and-resize to the
        # model's input size, conversion to a tensor, and normalization with
        # the model's own statistics.
        transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])
        return transforms
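
    # A minimal usage sketch of the pipeline built above (file name and
    # variable names are hypothetical): given a PIL image, the returned
    # `transforms` yields a normalized float tensor of shape (3, 224, 224)
    # for ViT-base.
    #
    #   from PIL import Image
    #   img = Image.open("example.jpg").convert("RGB")
    #   tensor = transforms(img)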

    def preprocess(self, images, transforms):
        # Compose operates on a single PIL image at a time, so convert and
        # transform each image individually.
        processed = [transforms(img.convert("RGB")) for img in images]
        return processed

    def process_all(self, X):
        # TODO: every element in X is a tuple (doc_id, clean_text, text,
        # PIL.Image), so we're taking just the last element for processing
        processed = torch.stack([self.transforms(img[-1]) for img in X])
        return processed
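
    # Illustrative call, per the TODO above (doc ids and file names are
    # hypothetical):
    #
    #   X = [("doc-0", "clean text", "raw text", Image.open("img-0.jpg")), ...]
    #   batch = self.process_all(X)  # tensor of shape (len(X), 3, 224, 224)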

    def fit(self, lX, lY):
        print("- fitting Visual Transformer View Generating Function")
        # Infer the number of labels from the label matrix of any one language.
        _l = list(lX.keys())[0]
        self.num_labels = lY[_l].shape[-1]
        self.model, self.image_preprocessor, self.transforms = self.init_model(
            self._validate_model_name(self.model_name), num_labels=self.num_labels
        )

        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
            lX, lY, split=0.2, seed=42
        )

        tra_dataloader = self.build_dataloader(
            tr_lX,
            tr_lY,
            processor_fn=self.process_all,
            torchDataset=MultimodalDatasetTorch,
            batch_size=self.batch_size,
            split="train",
            shuffle=True,
        )

        val_dataloader = self.build_dataloader(
            val_lX,
            val_lY,
            processor_fn=self.process_all,
            torchDataset=MultimodalDatasetTorch,
            batch_size=self.batch_size_eval,
            split="val",
            shuffle=False,
        )

        experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"
        trainer = Trainer(
            model=self.model,
            optimizer_name="adamW",
            device=self.device,
            loss_fn=torch.nn.CrossEntropyLoss(),
            lr=self.lr,
            print_steps=self.print_steps,
            evaluate_step=self.evaluate_step,
            patience=self.patience,
            experiment_name=experiment_name,
        )

        trainer.train(
            train_dataloader=tra_dataloader,
            eval_dataloader=val_dataloader,
            epochs=self.epochs,
        )

    def transform(self, lX):
        raise NotImplementedError

    def fit_transform(self, lX, lY):
        raise NotImplementedError

    def save_vgf(self, model_id):
        raise NotImplementedError


class MultimodalDatasetTorch(Dataset):
    """Torch dataset over the pre-processed images (and labels) of every
    language in the multilingual dataset.
    """

    def __init__(self, lX, lY, split="train"):
        self.lX = lX
        self.lY = lY
        self.split = split
        self.langs = []
        self.init()

    def init(self):
        # Stack the per-language image tensors (and label matrices) into
        # single tensors, keeping track of the language of every sample.
        self.X = torch.vstack([imgs for imgs in self.lX.values()])
        if self.split != "whole":
            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
        self.langs = [lang for lang, data in self.lX.items() for _ in range(len(data))]
        # print(f"- lX has shape: {self.X.shape}\n- lY has shape: {self.Y.shape}")

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        # The "whole" split is unlabeled, so only images and languages are returned.
        if self.split == "whole":
            return self.X[index], self.langs[index]
        return self.X[index], self.Y[index], self.langs[index]
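

# A minimal sketch (hypothetical helper, not part of the original API) of how
# MultimodalDatasetTorch is meant to be consumed: `processed_lX` is assumed to
# map each language to a tensor produced by VisualTransformerGen.process_all.
def _demo_dataloader(processed_lX, lY, batch_size=32):
    dataset = MultimodalDatasetTorch(processed_lX, lY, split="train")
    # Batches yield (images, labels, langs) triples, as defined in __getitem__.
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)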


if __name__ == "__main__":
    from os.path import expanduser

    from dataManager.multiNewsDataset import MultiNewsDataset

    _dataset_path_hardcoded = "~/datasets/MultiNews/20110730/"

    dataset = MultiNewsDataset(expanduser(_dataset_path_hardcoded), debug=True)
    lX, lY = dataset.training()

    vg = VisualTransformerGen(
        model_name="vit", device="cuda", epochs=1000, evaluate_step=10, patience=100
    )
    vg.fit(lX, lY)