# wcag_AI_validation/dependences/image_extractor.py

import argparse
import asyncio
import hashlib
import json
import os

from datetime import datetime, timezone
from typing import Dict, List, Optional
from urllib.parse import urljoin, urlparse

import requests
from playwright.async_api import async_playwright

from dependences.utils import create_folder, disclaim_bool_string, prepare_output_folder

class ImageExtractor:
    SUPPORTED_FORMATS = {"png", "jpeg", "jpg", "webp", "gif"}

    def __init__(
        self,
        url: str,
        context_levels: int = 5,
        pixel_distance_threshold: int = 200,
        number_of_images: int = 10,
        save_images=True,
        save_images_path="",
    ):
        """
        Initialize the ImageExtractor.

        Args:
            url: The page URL to extract images from.
            context_levels: Number of parent/child levels to traverse for context (default=5).
            pixel_distance_threshold: Maximum pixel distance for nearby text elements (default=200).
            number_of_images: Maximum number of images to extract.
            save_images: Whether to save the extracted images to disk.
            save_images_path: Directory in which to save the images.
        """
        self.url = url
        self.context_levels = context_levels
        self.pixel_distance_threshold = pixel_distance_threshold
        self.number_of_images = number_of_images
        self.save_images = save_images
        self.save_images_path = save_images_path

    def _is_supported_format(self, img_url: str) -> bool:
        """Check whether the image URL points to a supported format."""
        parsed = urlparse(img_url.lower())
        path = parsed.path
        # Check the file extension first
        for fmt in self.SUPPORTED_FORMATS:
            if path.endswith(f".{fmt}"):
                return True
        # Fall back to a substring check for query parameters (e.g., format=jpeg)
        return any(fmt in img_url.lower() for fmt in self.SUPPORTED_FORMATS)
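
    # Note: the substring fallback above is deliberately permissive and can
    # produce false positives. Illustrative (hypothetical) URLs:
    #   "https://cdn.example.com/pic.webp"        -> True (extension check)
    #   "https://cdn.example.com/img?format=jpeg" -> True (substring check)
    #   "https://example.com/gifts/catalog"       -> True ("gif" substring, false positive)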

    async def _download_image(self, image_url, output_dir="images") -> None:
        """Download an image and save it under output_dir with a sanitized filename."""
        # Parse the URL to get the path without query parameters
        parsed_url = urlparse(image_url)
        url_path = parsed_url.path
        # Get the filename from the path
        filename = url_path.split("/")[-1]
        # Split the filename and extension
        if "." in filename:
            image_name, ext = filename.rsplit(".", 1)
            ext = ext.lower()
        else:
            image_name = filename
            ext = "jpg"
        # Validate the extension
        if ext not in ["jpg", "jpeg", "png", "gif", "webp"]:
            ext = "jpg"
        # Sanitize the image name (keep alphanumerics, '-' and '_'; limit length)
        image_name = "".join(c for c in image_name if c.isalnum() or c in ("-", "_"))
        image_name = image_name[:50]  # Limit filename length
        # If the name is empty after sanitization, derive a hash-based name
        if not image_name:
            image_name = hashlib.md5(image_url.encode()).hexdigest()[:16]
        print("getting image url:", image_url)
        print("getting image name:", image_name)
        try:
            # Download the image (inside the try so network errors are also caught)
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()
            # Save the image
            output_path = os.path.join(output_dir, f"{image_name}.{ext}")
            print("saving image to:", output_path)
            with open(output_path, "wb") as f:
                f.write(response.content)
            print(f"Saved: {output_path}")
        except Exception as e:
            print(f"Error downloading or saving image {image_url}: {e}")

    async def save_elaboration(self, images, output_path) -> None:
        """Save the extracted image records as a JSON file at output_path."""
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(images, f, indent=2, ensure_ascii=False)
        print(f"\nResults saved to {output_path}")

    async def _get_element_context(self, page, img_element) -> tuple[str, str, str]:
        """
        Extract textual context around an image element from text-containing tags.

        Returns:
            Tuple of (full_context, immediate_context, nearby_text) where:
            - full_context: Text extracted with self.context_levels
            - immediate_context: Text extracted with context_level=1
            - nearby_text: Text within pixel_distance_threshold pixels of the image
        """
        try:
            # JavaScript helper that decides whether an element is visible.
            # Visibility checks:
            #   - visibility CSS property: excludes visibility: hidden / collapse
            #   - display CSS property: excludes display: none
            #   - opacity CSS property: excludes opacity: 0
            #   - element dimensions: excludes zero-width/zero-height (collapsed) elements
            visibility_check = """
            function isVisible(el) {
                if (!el) return false;
                const style = window.getComputedStyle(el);
                // Check visibility and display properties
                if (style.visibility === 'hidden' || style.visibility === 'collapse') return false;
                if (style.display === 'none') return false;
                if (style.opacity === '0') return false;
                // Check if the element has dimensions
                const rect = el.getBoundingClientRect();
                if (rect.width === 0 || rect.height === 0) return false;
                return true;
            }
            """
            # JavaScript function to extract text at a specific context level
            def get_context_js(levels):
                return f"""
                (element) => {{
                    {visibility_check}
                    // Text-containing tags to extract. Fuller alternative:
                    /* const textTags = ['p', 'span', 'div', 'a', 'li', 'td', 'th', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6',
                                         'label', 'figcaption', 'caption', 'blockquote', 'pre', 'code', 'em', 'strong',
                                         'b', 'i', 'u', 'small', 'mark', 'sub', 'sup', 'time', 'article', 'section']; */
                    const textTags = ['p', 'span', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'a'];
                    let textContent = [];
                    // Traverse up the DOM tree to the context root
                    let current = element;
                    for (let i = 0; i < {levels} && current.parentElement; i++) {{
                        current = current.parentElement;
                    }}
                    // Extract text from an element and its children
                    function extractText(el, depth = 0) {{
                        if (depth > {levels}) return;
                        // Skip if the element is not visible
                        if (!isVisible(el)) return;
                        // Get direct text content of text-containing elements
                        if (textTags.includes(el.tagName.toLowerCase())) {{
                            const text = el.textContent.trim();
                            if (text && text.length > 0) {{
                                textContent.push({{
                                    tag: el.tagName.toLowerCase(),
                                    text: text
                                }});
                            }}
                        }}
                        // Recursively process children
                        for (let child of el.children) {{
                            extractText(child, depth + 1);
                        }}
                    }}
                    // Extract text from the context root
                    extractText(current);
                    // Format as readable text (newline-separated variant kept for reference):
                    // return textContent.map(item => `<${{item.tag}}>: ${{item.text}}`).join('\\n\\n');
                    return textContent.map(item => `<${{item.tag}}>: ${{item.text}}`).join(' ');
                }}
                """
            # JavaScript function to extract nearby text based on pixel distance
            nearby_text_js = f"""
            (element) => {{
                {visibility_check}
                /* Fuller alternative:
                const textTags = ['p', 'span', 'div', 'a', 'li', 'td', 'th', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6',
                                  'label', 'figcaption', 'caption', 'blockquote', 'pre', 'code', 'em', 'strong',
                                  'b', 'i', 'u', 'small', 'mark', 'sub', 'sup', 'time']; */
                const textTags = ['p', 'span', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'a'];
                const threshold = {self.pixel_distance_threshold};
                const imgRect = element.getBoundingClientRect();
                const imgCenterX = imgRect.left + imgRect.width / 2;
                const imgCenterY = imgRect.top + imgRect.height / 2;
                // Distance between two rectangles, measured center to center.
                // This could be refined to use the nearest points instead of the centers.
                function getDistance(rect1, rect2) {{
                    const x1 = rect1.left + rect1.width / 2;
                    const y1 = rect1.top + rect1.height / 2;
                    const x2 = rect2.left + rect2.width / 2;
                    const y2 = rect2.top + rect2.height / 2;
                    // Euclidean distance
                    return Math.sqrt(Math.pow(x2 - x1, 2) + Math.pow(y2 - y1, 2));
                }}
                let nearbyElements = [];
                // Find all candidate text elements on the page
                const allElements = document.querySelectorAll(textTags.join(','));
                allElements.forEach(el => {{
                    // Skip if the element is not visible
                    if (!isVisible(el)) return;
                    const text = el.textContent.trim();
                    if (!text || text.length === 0) return;
                    // Skip if it is the image itself or contains the image
                    if (el === element || el.contains(element)) return;
                    const elRect = el.getBoundingClientRect();
                    const distance = getDistance(imgRect, elRect);
                    if (distance <= threshold) {{
                        nearbyElements.push({{
                            tag: el.tagName.toLowerCase(),
                            text: text,
                            distance: Math.round(distance)
                        }});
                    }}
                }});
                // Sort by distance
                nearbyElements.sort((a, b) => a.distance - b.distance);
                // Format output (newline-separated variant kept for reference):
                // return nearbyElements.map(item =>
                //     `<${{item.tag}}> [${{item.distance}}px]: ${{item.text}}`
                // ).join('\\n\\n');
                return nearbyElements.map(item =>
                    `<${{item.tag}}> [${{item.distance}}px]: ${{item.text}}`
                ).join(' ');
            }}
            """
            # Get the full context with self.context_levels
            full_context_js = get_context_js(self.context_levels)
            full_context = await img_element.evaluate(full_context_js)
            full_context = full_context if full_context else "No textual context found"
            # Get the immediate context with level=1
            immediate_context_js = get_context_js(1)
            immediate_context = await img_element.evaluate(immediate_context_js)
            immediate_context = (
                immediate_context if immediate_context else "No immediate context found"
            )
            # Get nearby text based on pixel distance
            nearby_text = await img_element.evaluate(nearby_text_js)
            nearby_text = nearby_text if nearby_text else "No nearby text found"
            return full_context, immediate_context, nearby_text
        except Exception as e:
            error_msg = f"Error extracting context: {str(e)}"
            return error_msg, error_msg, error_msg
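
    # Illustrative return values for an image next to a caption (hypothetical
    # page content), following the formatting used in the JS above:
    #   full_context      -> "<h2>: Latest news <p>: A caption near the image"
    #   immediate_context -> "<p>: A caption near the image"
    #   nearby_text       -> "<p> [54px]: A caption near the image"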

    async def _get_page_metadata(self, page):
        """Extract page metadata in one fast call: batch the DOM extraction inside a single evaluate()."""
        return await page.evaluate(
            """
            () => {
                const metadata = {
                    title: document.title || null,
                    description: null,
                    keywords: null,
                    headings: []
                };
                const desc = document.querySelector('meta[name="description"]');
                const keys = document.querySelector('meta[name="keywords"]');
                metadata.description = desc?.content || null;
                metadata.keywords = keys?.content || null;
                // Collect all headings h1-h6
                const allHeadings = document.querySelectorAll('h1, h2, h3, h4, h5, h6');
                metadata.headings = Array.from(allHeadings)
                    .map(h => ({
                        level: parseInt(h.tagName.substring(1), 10),
                        text: h.textContent.trim()
                    }))
                    .filter(h => h.text.length > 0);
                return metadata;
            }
            """
        )
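
    # Shape of the returned metadata (values are illustrative):
    #   {
    #     "title": "Example page",
    #     "description": "Meta description, or null if absent",
    #     "keywords": "comma, separated, keywords, or null",
    #     "headings": [{"level": 1, "text": "Example page"}, ...]
    #   }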

    async def extract_images(
        self, extract_context=True, specific_images_urls: Optional[List[str]] = None
    ) -> List[Dict]:
        """
        Extract all images from the page with their metadata and context.

        Returns:
            List of dictionaries containing image information
        """
        # Avoid a mutable default argument
        specific_images_urls = specific_images_urls or []
        async with async_playwright() as p:
            browser = await p.chromium.launch(headless=True)
            page = await browser.new_page()
            try:
                # Method 1: use if the page has unpredictable async content and everything
                # must load. The "networkidle" approach is generally more robust but slower,
                # while a fixed timeout is faster but less adaptive to actual page behavior.
                # await page.goto(self.url, wait_until="networkidle")
                # Alternative method 2: use when the page's loading pattern is well known
                # and faster, more predictable execution is wanted.
                await page.goto(
                    self.url, timeout=50000, wait_until="load"
                )  # default timeout=30000 (30 s)
                # Wait for dynamic content to load
                await page.wait_for_timeout(2000)
                if extract_context:
                    print("Getting page metadata...")
                    # Get the page metadata once
                    page_metadata = await self._get_page_metadata(page)
                    page_title = page_metadata["title"]
                    page_description = page_metadata["description"]
                    page_keywords = page_metadata["keywords"]
                    page_headings = page_metadata["headings"]
                else:
                    page_title = ""
                    page_description = ""
                    page_keywords = ""
                    page_headings = []
                if len(specific_images_urls) == 0:
                    # Find all img elements
                    print("Extracting all images from the page", self.url)
                else:
                    print(
                        "Extracting specific images from the page:",
                        self.url,
                        specific_images_urls,
                    )
                """ Method 3 (optimized approach, kept for reference):
                # Get all src attributes in one go
                all_img_elements = await page.locator("img").all()
                all_srcs = await page.locator("img").evaluate_all(
                    "elements => elements.map(el => el.src || '')"
                )
                # Filter with the pre-fetched src values
                img_elements = [
                    elem for elem, src in zip(all_img_elements, all_srcs)
                    if src in specific_images_urls
                ]
                """
                """ Method 2 (single pass to find matching images; more efficient than
                separate locator queries per URL and avoids timeout issues):
                for img_element in all_img_elements:
                    try:
                        src = await img_element.get_attribute("src")
                        print("found image src:", src)
                        if src in specific_images_urls:
                            img_elements.append(img_element)
                    except Exception as e:
                        print(f"Error getting src attribute from image: {str(e)}")
                """
                """ Method 1 (separate locator query for each specific URL):
                for url in specific_images_urls:
                    try:
                        # .first gets only the first match; timeout=0 means no timeout
                        img_element = await page.locator(
                            f'img[src="{url}"]'
                        ).first.element_handle(timeout=0)
                        if img_element:
                            img_elements.append(img_element)
                    except Exception as e:
                        print(f"Error locating image with src {url}: {str(e)}")
                """
                # Unified approach: start with all images and filter later
                img_elements = await page.locator("img").all()
                image_source_list = []  # avoids repeated checks for the same image URL
                images_data = []
                for img in img_elements:
                    # Limit the effective image list based on the init parameter
                    if len(images_data) >= self.number_of_images:
                        print(
                            "Reached the maximum number of images to extract:",
                            self.number_of_images,
                        )
                        break
                    try:
                        # Get the image src
                        src = await img.get_attribute("src")
                        if not src:
                            print("Image has no src attribute. Skipped.")
                            continue
                        if (
                            len(specific_images_urls) > 0
                            and src not in specific_images_urls
                        ):
                            # print("Image src", src, "not in the specific images list. Skipped.")
                            continue
                        if src not in image_source_list:
                            image_source_list.append(src)
                        else:
                            print("Image src", src, "already processed. Skipped.")
                            continue
                        # Convert relative URLs to absolute
                        img_url = urljoin(self.url, src)
                        # Verify the format
                        if not self._is_supported_format(img_url):
                            print(
                                "Image format not supported for url:",
                                img_url,
                                "Skipped.",
                            )
                            continue
                        if disclaim_bool_string(self.save_images):
                            print("save image:", img_url.split("/")[-1])
                            await self._download_image(
                                image_url=img_url, output_dir=self.save_images_path
                            )
                        # Get the alt text
                        alt_text = await img.get_attribute("alt") or ""
                        if extract_context:
                            print("Extracting context for image:", img_url)
                            # Get the surrounding HTML context (full, immediate, and nearby)
                            html_context, immediate_context, nearby_text = (
                                await self._get_element_context(page, img)
                            )
                        else:
                            html_context, immediate_context, nearby_text = "", "", ""
                        # Compile the image data
                        image_info = {
                            "url": img_url,
                            "alt_text": alt_text,
                            "html_context": html_context,
                            "immediate_context": immediate_context,
                            "nearby_text": nearby_text,
                            "page_url": self.url,
                            "page_title": page_title,
                            "page_description": page_description,
                            "page_keywords": page_keywords,
                            "page_headings": page_headings,
                        }
                        images_data.append(image_info)
                    except Exception as e:
                        print(f"Error processing image: {str(e)}")
                        continue
                return images_data
            finally:
                await browser.close()
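

# Minimal programmatic usage sketch (illustrative; the URL is hypothetical):
#
#   extractor = ImageExtractor("https://example.com", number_of_images=5,
#                              save_images=False)
#   images = asyncio.run(extractor.extract_images(extract_context=True))
#   for img in images:
#       print(img["url"], img["alt_text"])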


async def main(args):
    url = args.page_url
    context_levels = args.context_levels
    pixel_distance_threshold = args.pixel_distance_threshold
    number_of_images = args.number_of_images
    save_images = args.save_images
    print(
        "call ImageExtractor with:",
        "page_url:",
        url,
        "context_levels:",
        context_levels,
        "pixel_distance_threshold:",
        pixel_distance_threshold,
        "number_of_images:",
        number_of_images,
        "save_images:",
        save_images,
    )
    images_output_dir = ""  # avoids an unbound name when images are not saved
    if disclaim_bool_string(args.save_elaboration) or disclaim_bool_string(
        args.save_images
    ):  # if there is something to save
        url_path = url.replace(":", "").replace("//", "_").replace("/", "_")
        now = datetime.now(timezone.utc)
        now_str = now.strftime("%Y_%m_%d-%H_%M_%S")
        output_dir = prepare_output_folder(url_path, now_str)
        if disclaim_bool_string(args.save_images):
            images_output_dir = create_folder(
                output_dir, directory_separator="/", next_path="images"
            )
            print("save images path:", images_output_dir)
    # Create the extractor
    extractor = ImageExtractor(
        url,
        context_levels=context_levels,
        pixel_distance_threshold=pixel_distance_threshold,
        number_of_images=number_of_images,
        save_images=save_images,
        save_images_path=images_output_dir,
    )
    # Extract the images
    print(f"Extracting images from: {url}")
    images = await extractor.extract_images(specific_images_urls=[])
    print(f"\nFound {len(images)} supported images\n")
    # Display the results
    for i, img in enumerate(images, 1):
        print(f"Image {i}:")
        print(f"  URL: {img['url']}")
        print(f"  Alt text: {img['alt_text']}")
        print(f"  Page title: {img['page_title']}")
        print(f"  Full context length: {len(img['html_context'])} characters")
        print(f"  Immediate context length: {len(img['immediate_context'])} characters")
        print(f"  Nearby text length: {len(img['nearby_text'])} characters")
        print(f"  Number of headings on page: {len(img['page_headings'])}")
        print("-" * 80)
    # Optionally save the elaboration to JSON
    if disclaim_bool_string(args.save_elaboration):
        await extractor.save_elaboration(
            images, output_path=output_dir + "/extracted_images.json"
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--page_url",
        type=str,
        help="URL of the page to analyze",
        default="https://www.bbc.com",
    )
    parser.add_argument(
        "--context_levels",
        type=int,
        default=5,
        help="HTML context levels around the image",
    )
    parser.add_argument(
        "--pixel_distance_threshold",
        type=int,
        default=200,
        help="Pixel distance threshold around the image",
    )
    parser.add_argument(
        "--number_of_images",
        type=int,
        default=10,
        help="Maximum number of desired images",
    )
    # Note: action="store_true" combined with default=True would make these flags
    # always True, so they are taken as boolean strings and parsed with
    # disclaim_bool_string instead.
    parser.add_argument(
        "--save_elaboration",
        type=str,
        default="True",
        help="If 'True', save the elaborated info in a JSON file",
    )
    parser.add_argument(
        "--save_images",
        type=str,
        default="True",
        help="If 'True', save the images",
    )
    args = parser.parse_args()
    asyncio.run(main(args))
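
# Example invocation (illustrative):
#   python image_extractor.py --page_url https://www.bbc.com \
#       --number_of_images 5 --save_images False --save_elaboration True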