# wcag_AI_validation/scripts/hf_gemma3_finetuning_wcag_d...
from huggingface_hub import login
import os
import gc
import subprocess
from pathlib import Path
from huggingface_hub import snapshot_download
os.environ['HF_HOME'] = './cache_huggingface' # or just "." for directly in current folder
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
# Log in to the Hugging Face Hub
hf_token = os.environ.get("HF_TOKEN")  # or userdata.get('gemma3') if you are running inside a Google Colab
print("Logging into Hugging Face Hub...")
login(hf_token)
print("Logged in.")
from datasets import load_dataset
from PIL import Image
# System message for the assistant
system_message = "You are a web accessibility evaluation tool. Your task is to evaluate whether alternative text for images on webpages is appropriate according to WCAG guidelines."
# User prompt that combines the user query and the schema
user_prompt = """Create the most appropriate new alt-text given the image, the <HTML context>, and the current <alt-text>. Keep this within 30 words. Use the same language as the original alt-text.
Only return the new alt-text.
<alt-text>
{alttext}
</alt-text>
<HTML context>
{HTML_context}
</HTML context>
"""
def download_hf_model(model_id, output_dir="./hf_model"):
    """Download a model snapshot from the Hugging Face Hub."""
    print(f"Downloading {model_id} from Hugging Face...")
    model_path = snapshot_download(
        repo_id=model_id,
        local_dir=output_dir,
        local_dir_use_symlinks=False,
    )
    print(f"Model downloaded to: {model_path}")
    return model_path
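# Usage sketch (not part of this script's main flow; the output path is an example value):
# local_path = download_hf_model("google/gemma-3-4b-it", output_dir="./hf_model")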
def convert_to_gguf(model_path, output_path="./model.gguf"):
    """
    Convert a model to GGUF format using llama.cpp.
    Note: you need llama.cpp cloned locally so the convert_hf_to_gguf.py script is available.
    Clone from: https://github.com/ggerganov/llama.cpp
    """
    print("Converting to GGUF format...")
    # This assumes you have llama.cpp cloned and convert_hf_to_gguf.py available.
    # Adjust the path to your llama.cpp installation.
    convert_script = "./llama.cpp/convert_hf_to_gguf.py"
    cmd = [
        "python", convert_script,
        model_path,
        "--outfile", output_path,
        "--outtype", "f16",  # f16 keeps quality; smaller quantized variants can be produced afterwards with llama.cpp's quantize tool
    ]
    try:
        subprocess.run(cmd, check=True)
        print(f"GGUF model created: {output_path}")
    except (FileNotFoundError, subprocess.CalledProcessError):
        print("Error: llama.cpp convert_hf_to_gguf.py not found or conversion failed.")
        print("Please clone llama.cpp: git clone https://github.com/ggerganov/llama.cpp")
        return None
    return output_path
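# Optional follow-up sketch (not called anywhere in this script): once the f16 GGUF exists, llama.cpp can
# quantize it further. The binary path "./llama.cpp/llama-quantize" and the default output name below are
# assumptions and may differ depending on how you built llama.cpp.
def quantize_gguf(gguf_path, output_path="./model-q4_k_m.gguf", quant_type="Q4_K_M"):
    """Quantize an existing GGUF file with the llama.cpp quantize binary (sketch)."""
    cmd = ["./llama.cpp/llama-quantize", gguf_path, output_path, quant_type]
    try:
        subprocess.run(cmd, check=True)
        print(f"Quantized GGUF created: {output_path}")
        return output_path
    except (FileNotFoundError, subprocess.CalledProcessError):
        print("Quantization failed or the llama-quantize binary was not found.")
        return None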
def create_modelfile(model_name, gguf_path, template=None):
    """Create an Ollama Modelfile."""
    modelfile_content = f"""FROM {gguf_path}
# Set parameters
PARAMETER temperature 0.7
PARAMETER top_p 0.9
PARAMETER top_k 40
# Set the prompt template (adjust based on your model)
TEMPLATE """
    if template:
        modelfile_content += f'"""{template}"""'
    else:
        # Default template for chat models
        modelfile_content += '''"""{{ if .System }}System: {{ .System }}
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
{{ end }}Assistant: """'''
    modelfile_path = model_name + "Modelfile"
    with open(modelfile_path, "w") as f:
        f.write(modelfile_content)
    print(f"Modelfile created: {modelfile_path}")
    return modelfile_path
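# Sketch of the final Ollama import step (not invoked in this script): with Ollama installed, the Modelfile
# written above can be registered via the CLI (`ollama create <name> -f <Modelfile>`). The model name used
# when calling this helper is up to you.
def create_ollama_model(model_name, modelfile_path):
    """Register the GGUF model with a local Ollama install (sketch; assumes `ollama` is on PATH)."""
    try:
        subprocess.run(["ollama", "create", model_name, "-f", modelfile_path], check=True)
        print(f"Ollama model created: {model_name}")
    except (FileNotFoundError, subprocess.CalledProcessError):
        print("Could not create the Ollama model; is the `ollama` CLI installed and on PATH?")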
# NB: inference is run with the image input and the two text fields (and the same instruction used for fine-tuning)
def generate_description(dataset, model, processor):
    print("Generating description...")
    # Convert sample into messages and then apply the chat template
    """messages = [
        {"role": "system", "content": [{"type": "text", "text": system_message}]},
        {"role": "user", "content": [
            {"type": "image", "image": sample["image"]},
            {"type": "text", "text": user_prompt.format(product=sample["product_name"], category=sample["category"])},
        ]},
    ]"""
    ### take the first dataset element as a test sample
    # image_inputs = dataset[0]["image"]  # not a list, but otherwise the same as below
    # print("image_inputs_pre:", image_inputs)
    format_data_example = format_data(dataset[0])
    messages = format_data_example["messages"][0:2]  # do not pass the assistant part (the expected answer), as in the HF example
    print("User message:", messages)
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Process the image and text
    image_inputs = process_vision_info(messages)  # converts the image to RGB, even though the sample already seems to call .convert("RGB") above
    print("image_inputs:", image_inputs)
    # Tokenize the text and process the images
    inputs = processor(
        text=[text],
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Move the inputs to the device
    inputs = inputs.to(model.device)
    # Generate the output
    stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids("<end_of_turn>")]
    generated_ids = model.generate(**inputs, max_new_tokens=256, top_p=1.0, do_sample=True, temperature=0.8, eos_token_id=stop_token_ids, disable_compile=True)
    # Trim the generation and decode the output to text
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0]
# Convert a dataset sample to OpenAI-style chat messages
def format_data(sample):
    return {
        "messages": [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": user_prompt.format(
                            HTML_context=sample["html_context"],
                            alttext=sample["alt_text"],
                            # accessibility_expert_alt_text_assessment=sample["original_alt_text_assessment"],
                            # accessibility_expert_alt_text_comments=sample["evaluation_result"]
                        ),
                    },
                    {
                        "type": "image",
                        "image": sample["image"].convert("RGB"),  # is .convert("RGB") really necessary here?
                    },
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": sample["new_alt_text"]}],  # the assistant role carries the expected answer
            },
        ],
    }
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    print("Processing vision info...")
    image_inputs = []
    # Iterate through each conversation
    for msg in messages:
        # Get content (ensure it's a list)
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]
        # Check each content element for images
        for element in content:
            if isinstance(element, dict) and (
                "image" in element or element.get("type") == "image"
            ):
                # Get the image and convert it to RGB
                if "image" in element:
                    image = element["image"]
                else:
                    image = element
                image_inputs.append(image.convert("RGB"))  # convert to RGB!
    return image_inputs
print("Loading dataset...")
# Load dataset from the hub
#dataset = load_dataset("philschmid/amazon-product-descriptions-vlm", split="train",cache_dir="./dataset_cache")
dataset = load_dataset("nicolaleo/LLM-alt-text-assessment", split="train",cache_dir="./dataset_cache")
from copy import deepcopy
dataset_copy=deepcopy(dataset)
# Convert the dataset to OpenAI-style messages.
# A list comprehension is needed to keep the PIL.Image type; .map() would convert the images to bytes.
dataset = [format_data(sample) for sample in dataset]
print(dataset[0]["messages"])
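# Quick sanity check (the column names below are the fields format_data reads; adjust if the dataset schema differs):
expected_columns = {"image", "html_context", "alt_text", "new_alt_text"}
missing_columns = expected_columns - set(dataset_copy.column_names)
if missing_columns:
    print(f"Warning: dataset is missing expected columns: {missing_columns}")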
import torch
print("CUDA device capability:", torch.cuda.get_device_capability())
print("Freeing up memory...")
torch.cuda.empty_cache()
gc.collect()
# Get free memory in bytes
free_memory = torch.cuda.mem_get_info()[0]
total_memory = torch.cuda.mem_get_info()[1]
# Convert to GB for readability
free_gb = free_memory / (1024**3)
total_gb = total_memory / (1024**3)
print(f"Free: {free_gb:.2f} GB / Total: {total_gb:.2f} GB")
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
# Hugging Face model id
model_id = "google/gemma-3-4b-it"  # or "google/gemma-3-4b-pt", `google/gemma-3-12b-pt`, `google/gemma-3-27b-pt`
# Check if GPU benefits from bfloat16
#if torch.cuda.get_device_capability()[0] < 8:
# raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")
# Define model init arguments
model_kwargs = dict(
    attn_implementation="eager",  # use "flash_attention_2" when running on an Ampere or newer GPU
    torch_dtype=torch.bfloat16,  # which torch dtype to use, defaults to auto
    device_map="auto",  # let torch decide how to load the model
)
# BitsAndBytesConfig int-4 config
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)
# Load model and tokenizer
#model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
#processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
# Set the cache directory to current folder
cache_dir = "./model_cache" # or just "." for directly in current folder
print("Loading model... This may take a while.")
model = AutoModelForImageTextToText.from_pretrained(  # 4-bit quantized version
    model_id,
    cache_dir=cache_dir,
    **model_kwargs
)
print("Model loaded.")
proc_cache_dir = "./proc_cache"
print("Loading processor...")
processor = AutoProcessor.from_pretrained(
    "google/gemma-3-4b-it",  # same as model_id; the original example uses the -it checkpoint rather than -pt (makes little difference)
    cache_dir=proc_cache_dir
)
print("Processor loaded.")
print("testing the loaded model...")
# generate the description
description = generate_description(dataset_copy, model, processor)
print("text generated:",description)
# Download and save to current folder
print("Saving model and processor locally...")
save_path = "./original_local_model_"+model_id.replace("/", "_")
model.save_pretrained(save_path)
processor.save_pretrained(save_path)
print("Model and processor saved.")
""" # la convesrione in ollama funziona solo se fatta su modello non quantizzato (da capire se si può fare su modello 4bit)
print("Converting and importing model to Ollama...")
# Step 1: Download from Hugging Face
model_path= "./original_local_model_ollama"
model_path = download_hf_model(model_id,output_dir=model_path)
# Step 2: Convert to GGUF (requires llama.cpp)
gguf_path = convert_to_gguf(model_path, "./gemma.gguf")
if gguf_path:
# Step 3: Create Modelfile
OLLAMA_MODEL_NAME = "gemma3-wcag"
modelfile = create_modelfile(OLLAMA_MODEL_NAME, gguf_path)
"""
from peft import LoraConfig
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    # modules_to_save=[  # this is what used up the extra memory
    #     "lm_head",
    #     "embed_tokens",
    # ],
)
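# Optional sketch (kept commented out so SFTTrainer still applies the LoRA config itself): wrapping the
# model with PEFT manually is a quick way to see how few parameters the adapter actually trains.
# from peft import get_peft_model
# peft_preview = get_peft_model(model, peft_config)
# peft_preview.print_trainable_parameters()  # only a small fraction of the 4B base weights is trainable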
from trl import SFTConfig
args = SFTConfig(
    output_dir="./gemma-finetuned-wcag_" + model_id.replace("/", "_"),  # directory to save to and repository id
    num_train_epochs=1,  # number of training epochs
    per_device_train_batch_size=1,  # batch size per device during training
    gradient_accumulation_steps=4,  # number of steps before performing a backward/update pass
    gradient_checkpointing=True,  # use gradient checkpointing to save memory
    optim="adamw_torch_fused",  # use the fused adamw optimizer
    logging_steps=5,  # log every 5 steps
    save_strategy="epoch",  # save a checkpoint every epoch
    learning_rate=2e-4,  # learning rate, based on the QLoRA paper
    bf16=True,  # use bfloat16 precision
    max_grad_norm=0.3,  # max gradient norm, based on the QLoRA paper
    warmup_ratio=0.03,  # warmup ratio, based on the QLoRA paper
    lr_scheduler_type="constant",  # use a constant learning rate scheduler
    push_to_hub=True,  # push the model to the Hub
    report_to="tensorboard",  # report metrics to TensorBoard
    gradient_checkpointing_kwargs={
        "use_reentrant": False
    },  # use non-reentrant checkpointing
    dataset_text_field="",  # dummy field needed for the collator
    dataset_kwargs={"skip_prepare_dataset": True},  # important for the collator
)
args.remove_unused_columns = False  # important for the collator
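# Note (derived from the config above): the effective batch size is
# per_device_train_batch_size * gradient_accumulation_steps = 1 * 4 = 4 samples per optimizer step.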
# Create a data collator to encode text and image pairs
def collate_fn(examples):
    texts = []
    images = []
    for example in examples:
        image_inputs = process_vision_info(example["messages"])
        text = processor.apply_chat_template(
            example["messages"], add_generation_prompt=False, tokenize=False
        )
        texts.append(text.strip())
        images.append(image_inputs)
    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
    # The labels are the input_ids; padding tokens and image tokens are masked out of the loss computation
    labels = batch["input_ids"].clone()
    # Mask image tokens
    image_token_id = [
        processor.tokenizer.convert_tokens_to_ids(
            processor.tokenizer.special_tokens_map["boi_token"]
        )
    ]
    # Mask tokens so they are not used in the loss computation
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    labels[labels == 262144] = -100
    batch["labels"] = labels
    return batch
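# Quick collator smoke test (commented out; uncomment to verify the batch shapes before training):
# _test_batch = collate_fn(dataset[:2])
# print({k: v.shape for k, v in _test_batch.items()})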
from trl import SFTTrainer
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
)
print("Starting training...")
# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()
print("Training completed.")
# Save the final model (the LoRA adapter) to the output directory one more time
trainer.save_model()
# free the memory again
del model
del trainer
torch.cuda.empty_cache()
from peft import PeftModel
# Load the base model
model = AutoModelForImageTextToText.from_pretrained(model_id, low_cpu_mem_usage=True, cache_dir=cache_dir)
# Merge LoRA and base model and save
peft_model = PeftModel.from_pretrained(model, args.output_dir)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("merged_model_"+model_id.replace("/", "_"), safe_serialization=True, max_shard_size="2GB")
processor = AutoProcessor.from_pretrained(args.output_dir)
processor.save_pretrained("merged_model_"+model_id.replace("/", "_"))
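# Sketch (commented out): the merged full-precision checkpoint above is the kind of model the earlier
# GGUF/Ollama helpers expect, so an Ollama export could plausibly start from it. Paths and names below are examples.
# merged_dir = "merged_model_" + model_id.replace("/", "_")
# gguf_path = convert_to_gguf(merged_dir, "./gemma3-wcag.gguf")
# if gguf_path:
#     modelfile = create_modelfile("gemma3-wcag", gguf_path)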
print("Loading merged model for inference...")
# Load the model with the PEFT adapter
model = AutoModelForImageTextToText.from_pretrained(
    args.output_dir,  # arguably this should be the "merged_model_..." directory rather than ./gemma-finetuned-wcag...; the separate test script does use the merged model
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)
processor = AutoProcessor.from_pretrained(args.output_dir)
print("testing the merged model...")
"""
import requests
from PIL import Image
# Test sample with Product Name, Category and Image
sample = {
"product_name": "Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held, Iron Man, 30,5 cm Actionfigur",
"category": "Toys & Games | Toy Figures & Playsets | Action Figures",
"image": Image.open(requests.get("https://m.media-amazon.com/images/I/81+7Up7IWyL._AC_SY300_SX300_.jpg", stream=True).raw).convert("RGB")
}
"""
# generate the description
description = generate_description(dataset_copy, model, processor)
print("text generated:",description)