from huggingface_hub import login
import os
import gc
from copy import deepcopy

from datasets import load_dataset
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
from peft import PeftModel

# System message for the assistant
system_message = "You are a web accessibility evaluation tool. Your task is to evaluate whether the alternative text for images on webpages is appropriate according to WCAG guidelines."

# User prompt that combines the user query and the schema
user_prompt = """Create the most appropriate new alt-text given the image, the <HTML context>, and the current <alt-text>. Keep this within 30 words. Use the same language as the original alt-text.
Only return the new alt-text.

<alt-text>
{alttext}
</alt-text>

<HTML context>
{HTML_context}
</HTML context>

"""

def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    # print("Processing vision info...")
    image_inputs = []
    # Iterate through each conversation
    for msg in messages:
        # Get the content (ensure it's a list)
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]

        # Check each content element for images
        for element in content:
            if isinstance(element, dict) and (
                "image" in element or element.get("type") == "image"
            ):
                # Get the image and convert it to RGB
                if "image" in element:
                    image = element["image"]
                else:
                    image = element
                image_inputs.append(image.convert("RGB"))  # convert to RGB
    return image_inputs

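# Minimal sanity check (an illustrative sketch added here, not part of the original
# flow): process_vision_info() should pull the PIL images out of a messages list and
# return them converted to RGB. The 1x1 grayscale image below is synthetic.
_demo_messages = [{"role": "user", "content": [{"type": "image", "image": Image.new("L", (1, 1))}]}]
assert all(img.mode == "RGB" for img in process_vision_info(_demo_messages))
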
def format_data(sample):
    return {
        "messages": [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": user_prompt.format(
                            HTML_context=sample["html_context"],
                            alttext=sample["alt_text"],
                            # accessibility_expert_alt_text_assessment=sample["original_alt_text_assessment"],
                            # accessibility_expert_alt_text_comments=sample["evaluation_result"],
                        ),
                    },
                    {
                        "type": "image",
                        "image": sample["image"].convert("RGB"),  # .convert("RGB") may be redundant (done again in process_vision_info), kept for safety
                    },
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": sample["new_alt_text"]}],  # the assistant role carries the expected answer
            },
        ],
    }

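# Illustrative shape check (hypothetical sample, not taken from the dataset):
# format_data() is expected to return a system / user / assistant message triple
# ready for the chat template. All field values below are made up.
_demo_sample = {
    "html_context": "<img src='cat.jpg' alt='cat'>",
    "alt_text": "cat",
    "image": Image.new("RGB", (1, 1)),
    "new_alt_text": "A small cat sitting on a windowsill.",
}
assert [m["role"] for m in format_data(_demo_sample)["messages"]] == ["system", "user", "assistant"]
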
def generate_description(dataset, model, processor, example_idx=0):
    print("Generating description...")
    # Convert one dataset element into chat messages and apply the chat template
    format_data_example = format_data(dataset[example_idx])
    messages = format_data_example["messages"][0:2]  # drop the assistant part (the expected answer), as in the HF example
    # print("User message:", messages)
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Process the image and text
    image_inputs = process_vision_info(messages)  # converts the image to RGB (also already done in format_data)
    # print("image_inputs:", image_inputs)

    # Tokenize the text and process the images
    inputs = processor(
        text=[text],
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Move the inputs to the device
    inputs = inputs.to(model.device)

    # Generate the output
    stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids("<end_of_turn>")]
    generated_ids = model.generate(**inputs, max_new_tokens=256, top_p=1.0, do_sample=True, temperature=0.8, eos_token_id=stop_token_ids, disable_compile=True)
    # Trim the generation and decode the output to text
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0]

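# Optional convenience helper (a sketch, deliberately not used in the checks below so
# the original step-by-step flow stays visible): run generate_description() over a few
# example indices and print each result.
def generate_for_indices(dataset, model, processor, indices=(0, 1, 20)):
    for i, idx in enumerate(indices, start=1):
        print(f"-text generated {i}:", generate_description(dataset, model, processor, example_idx=idx))
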
os.environ['HF_HOME'] = './cache_huggingface'  # or just "." for directly in the current folder
# os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Login into the Hugging Face Hub.
# Read the token from the environment rather than hardcoding it;
# use userdata.get('gemma3') if you are running inside a Google Colab.
hf_token = os.environ.get("HF_TOKEN")
print("Logging into Hugging Face Hub...")
login(hf_token)
print("Logged in.")

model_id = "google/gemma-3-4b-it"
output_dir = "./merged_model"  # "./gemma-finetuned-wcag"

dataset = load_dataset("nicolaleo/LLM-alt-text-assessment", split="train", cache_dir="./dataset_cache")
dataset_copy = deepcopy(dataset)

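# Quick look at the dataset (columns assumed from their usage elsewhere in this
# script: "image", "html_context", "alt_text", "new_alt_text").
print("Dataset columns:", dataset.column_names, "| number of examples:", len(dataset))
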
cache_dir = "./model_cache"
proc_cache_dir = "./proc_cache"

model_kwargs = dict(
    attn_implementation="eager",  # Use "flash_attention_2" when running on an Ampere or newer GPU
    torch_dtype=torch.bfloat16,  # or torch.float16; which torch dtype to use, defaults to auto
    device_map="auto",  # Let torch decide how to load the model
)

# BitsAndBytesConfig int-4 config
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)

print("Freeing up memory...")
|
|
torch.cuda.empty_cache()
|
|
gc.collect()
|
|
|
|
# Load Model base model
|
|
model = AutoModelForImageTextToText.from_pretrained(model_id,cache_dir=cache_dir)
|
|
print("Model loaded #1")
|
|
#print(model)
|
|
|
|
#load pre-trained processor
|
|
processor = AutoProcessor.from_pretrained(
|
|
"google/gemma-3-4b-it",#model_id, # nel file originale prende -it e non -pt (cambia poco comunque)
|
|
cache_dir=proc_cache_dir
|
|
)
|
|
print("Processor loaded #1")
|
|
|
|
print("testing the model #1...")
|
|
# generate the description
|
|
description = generate_description(dataset_copy, model, processor,example_idx=0)
|
|
print("-text generated 1:",description)
|
|
|
|
description = generate_description(dataset_copy, model, processor,example_idx=1)
|
|
print("-text generated 2:",description)
|
|
|
|
description = generate_description(dataset_copy, model, processor,example_idx=20)
|
|
print("-text generated 3:",description)
|
|
|
|
print("Freeing up memory...")
|
|
torch.cuda.empty_cache()
|
|
gc.collect()
|
|
del model
|
|
|
|
# Load the model with 4-bit quantization
model = AutoModelForImageTextToText.from_pretrained(model_id, cache_dir=cache_dir, **model_kwargs)
print("\n Model loaded #2 with 4bit quantization")
# print(model)
processor = AutoProcessor.from_pretrained(
    "google/gemma-3-4b-it",  # model_id; the original file uses -it rather than -pt (little difference either way)
    cache_dir=proc_cache_dir,
)
print("Processor loaded #2")

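# Rough memory check (illustrative): get_memory_footprint() reports the model's
# parameter/buffer memory, which should be much smaller for this 4-bit load than
# for the bf16 load above.
print(f"Approx. model memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")
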
print("testing the model #2 with 4bit quantization...")
|
|
# generate the description
|
|
description = generate_description(dataset_copy, model, processor,example_idx=0)
|
|
print("-text generated 1:",description)
|
|
|
|
description = generate_description(dataset_copy, model, processor,example_idx=1)
|
|
print("-text generated 2:",description)
|
|
|
|
description = generate_description(dataset_copy, model, processor,example_idx=20)
|
|
print("-text generated 3:",description)
|
|
|
|
"""
|
|
# Merge LoRA and base model and save
|
|
peft_model = PeftModel.from_pretrained(model, output_dir)
|
|
merged_model = peft_model.merge_and_unload()
|
|
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")
|
|
|
|
processor = AutoProcessor.from_pretrained(output_dir)
|
|
processor.save_pretrained("merged_model")
|
|
|
|
|
|
print("Loading merged model for inference...")
|
|
# Load Model with PEFT adapter
|
|
model = AutoModelForImageTextToText.from_pretrained(
|
|
output_dir,
|
|
device_map="auto",
|
|
torch_dtype=torch.bfloat16,
|
|
attn_implementation="eager",
|
|
)
|
|
processor = AutoProcessor.from_pretrained(output_dir)
|
|
print("Model loaded #2")
|
|
print(model)
|
|
"""
|
|
|
|
print("Freeing up memory...")
|
|
torch.cuda.empty_cache()
|
|
gc.collect()
|
|
del model
|
|
# Load Model with PEFT adapter
|
|
model = AutoModelForImageTextToText.from_pretrained(
|
|
output_dir,
|
|
device_map="auto",
|
|
torch_dtype=torch.bfloat16,
|
|
attn_implementation="eager",
|
|
)
|
|
print("\n Model loaded #3")
|
|
processor = AutoProcessor.from_pretrained(output_dir)
|
|
print("Processor loaded #3")
|
|
#print(model)
|
|
|
|
|
|
print("testing the Merged model #3 ...")
|
|
|
|
|
|
#dataset = [format_data(sample) for sample in dataset]
|
|
|
|
# generate the description
|
|
description = generate_description(dataset_copy, model, processor,example_idx=0)
|
|
print("-text generated 1:",description)
|
|
|
|
description = generate_description(dataset_copy, model, processor,example_idx=1)
|
|
print("-text generated 2:",description)
|
|
|
|
description = generate_description(dataset_copy, model, processor,example_idx=20)
|
|
print("-text generated 3:",description) |