# gfun_multimodal/gfun/vgfs/visualTransformerGen.py
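"""Visual View Generating Function built on a pre-trained Vision Transformer.

The class below wraps a Hugging Face image-classification model (a ViT
checkpoint), fine-tunes it on the per-language label matrices passed to
``fit``, and in ``transform`` maps each image to its last-layer [CLS]
embedding, optionally projected to posterior probabilities.
"""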

from collections import defaultdict
import numpy as np
import torch
import transformers
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
from gfun.vgfs.commons import Trainer
from gfun.vgfs.transformerGen import TransformerGen
from gfun.vgfs.viewGen import ViewGen
from dataManager.torchDataset import MultimodalDatasetTorch

transformers.logging.set_verbosity_error()


class VisualTransformerGen(ViewGen, TransformerGen):
    def __init__(
        self,
        model_name,
        dataset_name,
        lr=1e-5,
        scheduler="ReduceLROnPlateau",
        epochs=10,
        batch_size=32,
        batch_size_eval=128,
        evaluate_step=10,
        device="cpu",
        probabilistic=False,
        patience=5,
        classification_type="multilabel",
    ):
        super().__init__(
            model_name,
            dataset_name,
            epochs=epochs,
            lr=lr,
            scheduler=scheduler,
            batch_size=batch_size,
            batch_size_eval=batch_size_eval,
            device=device,
            evaluate_step=evaluate_step,
            patience=patience,
            probabilistic=probabilistic,
        )
        self.clf_type = classification_type
        self.fitted = False
        print(
            f"- init Visual Transformer VGF, model_name: {self.model_name}, device: {self.device}"
        )

    def _validate_model_name(self, model_name):
        if model_name == "vit":
            return "google/vit-base-patch16-224-in21k"
        else:
            raise NotImplementedError(f"Unsupported visual model: {model_name}")

    def init_model(self, model_name, num_labels):
        model = AutoModelForImageClassification.from_pretrained(
            model_name, num_labels=num_labels, output_hidden_states=True
        )
        image_processor = AutoImageProcessor.from_pretrained(model_name)
        return model, image_processor

    def process_all(self, X):
        # TODO: should be moved into a collate_fn to avoid this overhead
        processed = self.image_preprocessor(
            [Image.open(img).convert("RGB") for img in X], return_tensors="pt"
        )
        return processed["pixel_values"]

    def fit(self, lX, lY):
        print("- fitting Visual Transformer View Generating Function")
        # infer the number of target labels from the first available language
        _l = list(lX.keys())[0]
        self.num_labels = lY[_l].shape[-1]
        self.model, self.image_preprocessor = self.init_model(
            self._validate_model_name(self.model_name), num_labels=self.num_labels
        )
        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
            lX, lY, split=0.2, seed=42, modality="image"
        )
        tra_dataloader = self.build_dataloader(
            tr_lX,
            tr_lY,
            processor_fn=self.process_all,
            torchDataset=MultimodalDatasetTorch,
            batch_size=self.batch_size,
            split="train",
            shuffle=True,
        )
        val_dataloader = self.build_dataloader(
            val_lX,
            val_lY,
            processor_fn=self.process_all,
            torchDataset=MultimodalDatasetTorch,
            batch_size=self.batch_size_eval,
            split="val",
            shuffle=False,
        )
        experiment_name = (
            f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
        )
        trainer = Trainer(
            model=self.model,
            optimizer_name="adamW",
            device=self.device,
            loss_fn=torch.nn.CrossEntropyLoss(),
            lr=self.lr,
            print_steps=self.print_steps,
            evaluate_step=self.evaluate_step,
            patience=self.patience,
            experiment_name=experiment_name,
            checkpoint_path="models/vgfs/transformer",
            vgf_name="visual_trf",
            classification_type=self.clf_type,
            n_jobs=self.n_jobs,
        )
        trainer.train(
            train_dataloader=tra_dataloader,
            eval_dataloader=val_dataloader,
            epochs=self.epochs,
        )
        if self.probabilistic:
            # calibrate the feature-to-posterior projection on the training embeddings
            self.feature2posterior_projector.fit(self.transform(lX), lY)
        self.fitted = True
        return self

    def transform(self, lX):
        # forcing to only image modality
        lX = {lang: data["image"] for lang, data in lX.items()}
        _embeds = []
        l_embeds = defaultdict(list)
        dataloader = self.build_dataloader(
            lX,
            lY=None,
            processor_fn=self.process_all,
            torchDataset=MultimodalDatasetTorch,
            batch_size=self.batch_size_eval,
            split="whole",
            shuffle=False,
        )
        self.model.eval()
        with torch.no_grad():
            for input_ids, lang in dataloader:
                input_ids = input_ids.to(self.device)
                out = self.model(input_ids).hidden_states[-1]
                # keep the last-layer [CLS] token as the image representation
                batch_embeddings = out[:, 0, :].cpu().numpy()
                _embeds.append((batch_embeddings, lang))
        # regroup the batched embeddings by language
        for embed, lang in _embeds:
            for sample_embed, sample_lang in zip(embed, lang):
                l_embeds[sample_lang].append(sample_embed)
        if self.probabilistic and self.fitted:
            l_embeds = self.feature2posterior_projector.transform(l_embeds)
        elif not self.probabilistic and self.fitted:
            l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}
        return l_embeds

    def fit_transform(self, lX, lY):
        return self.fit(lX, lY).transform(lX)

    def save_vgf(self, model_id):
        import pickle
        from os import makedirs
        from os.path import join

        vgf_name = "visualTransformerGen"
        _basedir = join("models", "vgfs", "visual_transformer")
        makedirs(_basedir, exist_ok=True)
        _path = join(_basedir, f"{vgf_name}_{model_id}.pkl")
        with open(_path, "wb") as f:
            pickle.dump(self, f)
        return self

    def __str__(self):
        return (
            f"[Visual Transformer VGF (v)]\n- model_name: {self.model_name}\n"
            f"- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n"
            f"- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n"
            f"- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n"
            f"- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
        )

    def get_config(self):
        return {"visual_trf": super().get_config()}