from huggingface_hub import login
import os
import gc
import subprocess
from pathlib import Path
from huggingface_hub import snapshot_download

os.environ['HF_HOME'] = './cache_huggingface'  # or just "." to cache directly in the current folder
# os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Log in to the Hugging Face Hub
hf_token = os.environ.get("HF_TOKEN")  # read the token from the environment; use userdata.get('gemma3') if you are running inside Google Colab
print("Logging into Hugging Face Hub...")
|
|
login(hf_token)
|
|
print("Logged in.")
|
|
from datasets import load_dataset
from PIL import Image

# System message for the assistant
system_message = "You are a web accessibility evaluation tool. Your task is to evaluate whether the alternative text for images on webpages is appropriate according to the WCAG guidelines."

# User prompt template that combines the current alt text and its HTML context
user_prompt = """Create the most appropriate new alt-text given the image, the <HTML context>, and the current <alt-text>. Keep this within 30 words. Use the same language as the original alt-text.
Only return the new alt-text.

<alt-text>
{alttext}
</alt-text>

<HTML context>
{HTML_context}
</HTML context>

"""

def download_hf_model(model_id, output_dir="./hf_model"):
    """Download the model from the Hugging Face Hub."""
    print(f"Downloading {model_id} from Hugging Face...")
    model_path = snapshot_download(
        repo_id=model_id,
        local_dir=output_dir,
        local_dir_use_symlinks=False,
    )
    print(f"Model downloaded to: {model_path}")
    return model_path

def convert_to_gguf(model_path, output_path="./model.gguf"):
    """
    Convert the model to GGUF format using llama.cpp.

    Note: you need llama.cpp cloned locally so that its convert_hf_to_gguf.py script is available.
    Clone from: https://github.com/ggerganov/llama.cpp
    """
    print("Converting to GGUF format...")

    # This assumes you have llama.cpp cloned and convert_hf_to_gguf.py available;
    # adjust the path to your llama.cpp installation.
    convert_script = "./llama.cpp/convert_hf_to_gguf.py"

    cmd = [
        "python", convert_script,
        model_path,
        "--outfile", output_path,
        "--outtype", "f16",  # use f16 for better quality, q4_0 for a smaller file
    ]

    try:
        subprocess.run(cmd, check=True)
        print(f"GGUF model created: {output_path}")
    except (FileNotFoundError, subprocess.CalledProcessError):
        # FileNotFoundError: the python interpreter could not be launched;
        # CalledProcessError: the conversion script failed (e.g. it is missing or errored out).
        print("Error: could not run llama.cpp's convert_hf_to_gguf.py.")
        print("Please clone llama.cpp: git clone https://github.com/ggerganov/llama.cpp")
        return None

    return output_path

def create_modelfile(model_name, gguf_path, template=None):
    """Create an Ollama Modelfile for the GGUF model."""
    modelfile_content = f"""FROM {gguf_path}

# Set parameters
PARAMETER temperature 0.7
PARAMETER top_p 0.9
PARAMETER top_k 40

# Set the prompt template (adjust based on your model)
TEMPLATE """

    if template:
        modelfile_content += f'"""{template}"""'
    else:
        # Default template for chat models
        modelfile_content += '''"""{{ if .System }}System: {{ .System }}
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
{{ end }}Assistant: """'''

    modelfile_path = model_name + "Modelfile"
    with open(modelfile_path, "w") as f:
        f.write(modelfile_content)

    print(f"Modelfile created: {modelfile_path}")
    return modelfile_path
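
# Hedged usage sketch (comments only, not executed here): once a Modelfile has been created, the GGUF
# model can be registered and tried out with the Ollama CLI, for example:
#   modelfile = create_modelfile("gemma3-wcag", "./gemma.gguf")
#   subprocess.run(["ollama", "create", "gemma3-wcag", "-f", modelfile], check=True)
#   subprocess.run(["ollama", "run", "gemma3-wcag", "Hello"], check=True)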


# NB: inference is done with the image input plus the two text fields (and the same instruction used for fine-tuning)
def generate_description(dataset, model, processor):
    print("Generating description...")
    # Convert the sample into messages and then apply the chat template
    """messages = [
        {"role": "system", "content": [{"type": "text", "text": system_message}]},
        {"role": "user", "content": [
            {"type": "image", "image": sample["image"]},
            {"type": "text", "text": user_prompt.format(product=sample["product_name"], category=sample["category"])},
        ]},
    ]"""

    ### Take the first element as a test
    # image_inputs = dataset[0]["image"]  # not a list, but otherwise the same as below
    # print("image_inputs_pre:", image_inputs)
    format_data_example = format_data(dataset[0])
    messages = format_data_example["messages"][0:2]  # do not pass the assistant part (the expected answer), as done in the HF example
    print("User message:", messages)
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Process the image and text
    image_inputs = process_vision_info(messages)  # converts the image to RGB, even though the sample above already calls .convert("RGB")
    print("image_inputs:", image_inputs)

    # Tokenize the text and process the images
    inputs = processor(
        text=[text],
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Move the inputs to the device
    inputs = inputs.to(model.device)

    # Generate the output
    stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids("<end_of_turn>")]
    generated_ids = model.generate(**inputs, max_new_tokens=256, top_p=1.0, do_sample=True, temperature=0.8, eos_token_id=stop_token_ids, disable_compile=True)
    # Trim the generation and decode the output to text
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0]


# Convert a dataset sample to OpenAI-style messages
def format_data(sample):
    return {
        "messages": [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": user_prompt.format(
                            HTML_context=sample["html_context"],
                            alttext=sample["alt_text"],
                            # accessibility_expert_alt_text_assessment=sample["original_alt_text_assessment"],
                            # accessibility_expert_alt_text_comments=sample["evaluation_result"]
                        ),
                    },
                    {
                        "type": "image",
                        "image": sample["image"].convert("RGB"),  # is .convert("RGB") actually needed here?
                    },
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": sample["new_alt_text"]}],  # the assistant role carries the expected answer
            },
        ],
    }


def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    print("Processing vision info...")
    image_inputs = []
    # Iterate through each conversation message
    for msg in messages:
        # Get the content (ensure it is a list)
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]

        # Check each content element for images
        for element in content:
            if isinstance(element, dict) and (
                "image" in element or element.get("type") == "image"
            ):
                # Get the image and convert it to RGB
                if "image" in element:
                    image = element["image"]
                else:
                    image = element
                image_inputs.append(image.convert("RGB"))  # convert to RGB
    return image_inputs

print("Loading dataset...")
# Load the dataset from the Hub
# dataset = load_dataset("philschmid/amazon-product-descriptions-vlm", split="train", cache_dir="./dataset_cache")
dataset = load_dataset("nicolaleo/LLM-alt-text-assessment", split="train", cache_dir="./dataset_cache")


from copy import deepcopy

dataset_copy = deepcopy(dataset)


# Convert the dataset to OpenAI-style messages.
# A list comprehension is needed to keep the PIL.Image type; .map would convert the images to bytes.
dataset = [format_data(sample) for sample in dataset]

print(dataset[0]["messages"])

import torch
torch.cuda.get_device_capability()

print("Freeing up memory...")
torch.cuda.empty_cache()
gc.collect()

# Get free memory in bytes
free_memory = torch.cuda.mem_get_info()[0]
total_memory = torch.cuda.mem_get_info()[1]

# Convert to GB for readability
free_gb = free_memory / (1024**3)
total_gb = total_memory / (1024**3)

print(f"Free: {free_gb:.2f} GB / Total: {total_gb:.2f} GB")

from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

# Hugging Face model id
model_id = "google/gemma-3-4b-it"  # or "google/gemma-3-4b-pt", "google/gemma-3-12b-pt", "google/gemma-3-27b-pt"

# Check if the GPU benefits from bfloat16
# if torch.cuda.get_device_capability()[0] < 8:
#     raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")

# Define model init arguments
model_kwargs = dict(
    attn_implementation="eager",  # use "flash_attention_2" when running on an Ampere or newer GPU
    torch_dtype=torch.bfloat16,  # which torch dtype to use, defaults to auto
    device_map="auto",  # let torch decide how to load the model
)

# BitsAndBytesConfig int-4 config
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)

# Load model and tokenizer
# model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
# processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")


# Set the cache directory
cache_dir = "./model_cache"  # or just "." to cache directly in the current folder

print("Loading model... This may take a while.")
model = AutoModelForImageTextToText.from_pretrained(  # 4-bit quantized version
    model_id,
    cache_dir=cache_dir,
    **model_kwargs,
)
print("Model loaded.")


proc_cache_dir = "./proc_cache"
print("Loading processor...")
processor = AutoProcessor.from_pretrained(
    "google/gemma-3-4b-it",  # model_id; the original file uses the -it variant rather than -pt (makes little difference here)
    cache_dir=proc_cache_dir,
)
print("Processor loaded.")


print("Testing the loaded model...")
# Generate a description with the base (not yet fine-tuned) model
description = generate_description(dataset_copy, model, processor)
print("text generated:", description)


# Download and save to the current folder
print("Saving model and processor locally...")
save_path = "./original_local_model_" + model_id.replace("/", "_")
model.save_pretrained(save_path)
processor.save_pretrained(save_path)
print("Model and processor saved.")


""" # The Ollama conversion only works on a non-quantized model (still to be verified whether it can be done on the 4-bit model)
print("Converting and importing model to Ollama...")
# Step 1: Download from Hugging Face
model_path = "./original_local_model_ollama"
model_path = download_hf_model(model_id, output_dir=model_path)

# Step 2: Convert to GGUF (requires llama.cpp)
gguf_path = convert_to_gguf(model_path, "./gemma.gguf")

if gguf_path:
    # Step 3: Create the Modelfile
    OLLAMA_MODEL_NAME = "gemma3-wcag"
    modelfile = create_modelfile(OLLAMA_MODEL_NAME, gguf_path)

"""


from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    # modules_to_save=[  # this is what was taking extra memory
    #     "lm_head",
    #     "embed_tokens",
    # ],
)

from trl import SFTConfig

args = SFTConfig(
    output_dir="./gemma-finetuned-wcag_" + model_id.replace("/", "_"),  # directory to save to and repository id
    num_train_epochs=1,  # number of training epochs
    per_device_train_batch_size=1,  # batch size per device during training
    gradient_accumulation_steps=4,  # number of steps before performing a backward/update pass
    gradient_checkpointing=True,  # use gradient checkpointing to save memory
    optim="adamw_torch_fused",  # use the fused AdamW optimizer
    logging_steps=5,  # log every 5 steps
    save_strategy="epoch",  # save a checkpoint every epoch
    learning_rate=2e-4,  # learning rate, based on the QLoRA paper
    bf16=True,  # use bfloat16 precision
    max_grad_norm=0.3,  # max gradient norm, based on the QLoRA paper
    warmup_ratio=0.03,  # warmup ratio, based on the QLoRA paper
    lr_scheduler_type="constant",  # use a constant learning rate scheduler
    push_to_hub=True,  # push the model to the Hub
    report_to="tensorboard",  # report metrics to TensorBoard
    gradient_checkpointing_kwargs={
        "use_reentrant": False
    },  # use non-reentrant checkpointing
    dataset_text_field="",  # dummy field needed for the collator
    dataset_kwargs={"skip_prepare_dataset": True},  # important for the collator
)
args.remove_unused_columns = False  # important for the collator
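
# Note (derived from the values above): the effective batch size is
# per_device_train_batch_size * gradient_accumulation_steps = 1 * 4 = 4 samples per optimizer step (per device).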

# Create a data collator to encode text and image pairs
def collate_fn(examples):
    texts = []
    images = []
    for example in examples:
        image_inputs = process_vision_info(example["messages"])
        text = processor.apply_chat_template(
            example["messages"], add_generation_prompt=False, tokenize=False
        )
        texts.append(text.strip())
        images.append(image_inputs)

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # The labels are the input_ids; padding and image tokens are masked out of the loss computation
    labels = batch["input_ids"].clone()

    # Mask image tokens
    image_token_id = [
        processor.tokenizer.convert_tokens_to_ids(
            processor.tokenizer.special_tokens_map["boi_token"]
        )
    ]
    # Mask tokens so they are not used in the loss computation
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    labels[labels == 262144] = -100  # 262144 is the id of the Gemma 3 image placeholder (soft) token

    batch["labels"] = labels
    return batch
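
# Quick sanity check of the collator (illustrative, commented out; assumes dataset[0] is a formatted sample):
# _batch = collate_fn([dataset[0]])
# print(_batch["input_ids"].shape, _batch["labels"].shape)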

from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
)
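
# Optional: inspect how many parameters LoRA actually trains (assumes the trainer wraps the model as a
# PEFT model exposing print_trainable_parameters(); illustrative, commented out):
# trainer.model.print_trainable_parameters()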

print("Starting training...")
# Start training; the model will be automatically saved to the Hub and to the output directory
trainer.train()

print("Training completed.")
# Save the final model (the adapter) to the output directory and, since push_to_hub=True, to the Hugging Face Hub
trainer.save_model()

# Free the memory again
del model
del trainer
torch.cuda.empty_cache()

from peft import PeftModel

# Load the base model
model = AutoModelForImageTextToText.from_pretrained(model_id, low_cpu_mem_usage=True, cache_dir=cache_dir)

# Merge the LoRA adapter into the base model and save
peft_model = PeftModel.from_pretrained(model, args.output_dir)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("merged_model_" + model_id.replace("/", "_"), safe_serialization=True, max_shard_size="2GB")

processor = AutoProcessor.from_pretrained(args.output_dir)
processor.save_pretrained("merged_model_" + model_id.replace("/", "_"))


print("Loading merged model for inference...")
# Load the model with the PEFT adapter
model = AutoModelForImageTextToText.from_pretrained(
    args.output_dir,  # this should probably point to the "merged_model_*" directory rather than ./gemma-finetuned-wcag_*; the separate test script uses the merged model
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)
processor = AutoProcessor.from_pretrained(args.output_dir)


print("Testing the merged model...")


"""
import requests
from PIL import Image

# Test sample with product name, category and image
sample = {
    "product_name": "Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held, Iron Man, 30,5 cm Actionfigur",
    "category": "Toys & Games | Toy Figures & Playsets | Action Figures",
    "image": Image.open(requests.get("https://m.media-amazon.com/images/I/81+7Up7IWyL._AC_SY300_SX300_.jpg", stream=True).raw).convert("RGB")
}
"""


# Generate the description with the fine-tuned model
description = generate_description(dataset_copy, model, processor)
print("text generated:", description)