# wcag_AI_validation/scripts/test_finetuned_model.py

import gc
import os
from copy import deepcopy

import torch
from datasets import load_dataset
from huggingface_hub import login
from peft import PeftModel
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
# System message for the assistant
system_message = "You are a web accessibility evaluation tool. Your task is to evaluate whether the alternative text for images on webpages is appropriate according to WCAG guidelines."
# User prompt template combining the image's current alt-text with its HTML context
user_prompt = """Create the most appropriate new alt-text given the image, the <HTML context>, and the current <alt-text>. Keep this within 30 words. Use the same language as the original alt-text.
Only return the new alt-text.
<alt-text>
{alttext}
</alt-text>
<HTML context>
{HTML_context}
</HTML context>
"""
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    """Collect every image referenced in a chat-formatted message list and convert it to RGB."""
    image_inputs = []
    # Iterate through each message of the conversation
    for msg in messages:
        # Get the content (ensure it's a list)
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]
        # Check each content element for images
        for element in content:
            if isinstance(element, dict) and (
                "image" in element or element.get("type") == "image"
            ):
                # Get the image and convert it to RGB
                if "image" in element:
                    image = element["image"]
                else:
                    image = element
                image_inputs.append(image.convert("RGB"))  # convert to RGB
    return image_inputs
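# Usage sketch (hypothetical message list; the in-memory 2x2 image is just a placeholder):
#   demo_messages = [{"role": "user",
#                     "content": [{"type": "image", "image": Image.new("RGB", (2, 2))}]}]
#   process_vision_info(demo_messages)  # -> [<PIL.Image.Image mode=RGB size=2x2>]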
def format_data(sample):
    """Turn one dataset row into the chat message structure expected by the processor."""
    return {
        "messages": [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": user_prompt.format(
                            HTML_context=sample["html_context"],
                            alttext=sample["alt_text"],
                            #accessibility_expert_alt_text_assessment=sample["original_alt_text_assessment"],
                            #accessibility_expert_alt_text_comments=sample["evaluation_result"]
                        ),
                    },
                    {
                        "type": "image",
                        # is .convert("RGB") necessary here? process_vision_info converts again anyway
                        "image": sample["image"].convert("RGB"),
                    },
                ],
            },
            {
                "role": "assistant",
                # The assistant turn carries the expected answer (the reference new alt-text)
                "content": [{"type": "text", "text": sample["new_alt_text"]}],
            },
        ],
    }
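# Resulting structure (sketch): format_data(sample)["messages"] is a three-turn conversation,
#   [system turn with system_message,
#    user turn with the filled user_prompt plus the RGB image,
#    assistant turn with sample["new_alt_text"] as the reference answer]
# which is the structure apply_chat_template consumes in generate_description below.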
def generate_description(dataset, model, processor, example_idx=0):
    """Run the model on one dataset sample and return the generated alt-text."""
    print("Generating description...")
    # Convert the sample into messages, keeping only the system and user turns:
    # the assistant turn (the expected answer) is not passed to the model.
    format_data_example = format_data(dataset[example_idx])
    messages = format_data_example["messages"][0:2]
    # Apply the chat template to obtain the prompt text
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Extract the image(s) from the messages (converted to RGB again; harmless even though
    # format_data already calls .convert("RGB"))
    image_inputs = process_vision_info(messages)
    # Tokenize the text and process the images
    inputs = processor(
        text=[text],
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Move the inputs to the model's device
    inputs = inputs.to(model.device)
    # Generate the output, stopping at EOS or the <end_of_turn> token
    stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids("<end_of_turn>")]
    generated_ids = model.generate(**inputs, max_new_tokens=256, top_p=1.0, do_sample=True, temperature=0.8, eos_token_id=stop_token_ids, disable_compile=True)
    # Trim the prompt tokens from the generation and decode the output to text
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0]
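# Decoding note (sketch, not part of the original script): for reproducible comparisons between
# the base, quantized, and merged models, sampling could be disabled inside generate_description;
# with do_sample=False, generation is greedy and temperature/top_p are ignored:
#   generated_ids = model.generate(**inputs, max_new_tokens=256, do_sample=False,
#                                  eos_token_id=stop_token_ids, disable_compile=True)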
os.environ['HF_HOME'] = './cache_huggingface'  # or just "." for the current folder; note: ideally set before the huggingface imports above, otherwise the default cache location may already be in use
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
# Login into Hugging Face Hub
hf_token = "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"#userdata.get('gemma3') # If you are running inside a Google Colab
print("Logging into Hugging Face Hub...")
login(hf_token)
print("Logged in.")
model_id = "google/gemma-3-4b-it"
output_dir="./merged_model"#"./gemma-finetuned-wcag"
dataset = load_dataset("nicolaleo/LLM-alt-text-assessment", split="train",cache_dir="./dataset_cache")
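# Quick inspection sketch (column names taken from the format_data() accesses above):
#   sample = dataset[0]
#   print(sample["alt_text"], sample["new_alt_text"])
#   print(sample["html_context"][:200])
#   sample["image"].size  # PIL image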
dataset_copy = deepcopy(dataset)
cache_dir = "./model_cache"
proc_cache_dir = "./proc_cache"
model_kwargs = dict(
    attn_implementation="eager",  # use "flash_attention_2" when running on an Ampere or newer GPU
    torch_dtype=torch.bfloat16,  # or torch.float16; which torch dtype to use, defaults to auto
    device_map="auto",  # let torch decide how to place the model
)
# BitsAndBytesConfig int-4 config
model_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)
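# Alternative sketch (assumes the bitsandbytes int8 kernels are available on the target GPU):
# 8-bit loading uses more memory than nf4 but usually preserves quality slightly better:
#   model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)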
print("Freeing up memory...")
torch.cuda.empty_cache()
gc.collect()
# Load the base model (no quantization)
model = AutoModelForImageTextToText.from_pretrained(model_id, cache_dir=cache_dir)
print("Model loaded #1")
#print(model)
# Load the pre-trained processor
processor = AutoProcessor.from_pretrained(
    "google/gemma-3-4b-it",  # same as model_id; the reference script uses the -it variant rather than -pt (makes little difference)
    cache_dir=proc_cache_dir
)
print("Processor loaded #1")
print("testing the model #1...")
# generate the description
description = generate_description(dataset_copy, model, processor,example_idx=0)
print("-text generated 1:",description)
description = generate_description(dataset_copy, model, processor,example_idx=1)
print("-text generated 2:",description)
description = generate_description(dataset_copy, model, processor,example_idx=20)
print("-text generated 3:",description)
print("Freeing up memory...")
torch.cuda.empty_cache()
gc.collect()
del model
# Load the model with 4-bit quantization
model = AutoModelForImageTextToText.from_pretrained(model_id, cache_dir=cache_dir, **model_kwargs)
print("\nModel loaded #2 with 4-bit quantization")
#print(model)
processor = AutoProcessor.from_pretrained(
    "google/gemma-3-4b-it",  # same as model_id; the reference script uses the -it variant rather than -pt (makes little difference)
    cache_dir=proc_cache_dir
)
print("Processor loaded #2")
print("Testing the model #2 with 4-bit quantization...")
# Generate descriptions for the same sample indices
description = generate_description(dataset_copy, model, processor, example_idx=0)
print("-text generated 1:", description)
description = generate_description(dataset_copy, model, processor, example_idx=1)
print("-text generated 2:", description)
description = generate_description(dataset_copy, model, processor, example_idx=20)
print("-text generated 3:", description)
"""
# Merge LoRA and base model and save
peft_model = PeftModel.from_pretrained(model, output_dir)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")
processor = AutoProcessor.from_pretrained(output_dir)
processor.save_pretrained("merged_model")
print("Loading merged model for inference...")
# Load Model with PEFT adapter
model = AutoModelForImageTextToText.from_pretrained(
    output_dir,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)
processor = AutoProcessor.from_pretrained(output_dir)
print("Model loaded #2")
print(model)
"""
print("Freeing up memory...")
torch.cuda.empty_cache()
gc.collect()
del model
# Load the merged fine-tuned model (the LoRA adapter already merged into the base weights)
model = AutoModelForImageTextToText.from_pretrained(
    output_dir,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)
print("\nModel loaded #3")
processor = AutoProcessor.from_pretrained(output_dir)
print("Processor loaded #3")
#print(model)
print("testing the Merged model #3 ...")
#dataset = [format_data(sample) for sample in dataset]
# generate the description
description = generate_description(dataset_copy, model, processor,example_idx=0)
print("-text generated 1:",description)
description = generate_description(dataset_copy, model, processor,example_idx=1)
print("-text generated 2:",description)
description = generate_description(dataset_copy, model, processor,example_idx=20)
print("-text generated 3:",description)