# wcag_AI_validation/UI/user_task_assignment/user_assignment_manager.py
#
# Repository viewer metadata: 634 lines, 20 KiB, Python

"""
User Assignment Manager
=======================

Manages user-to-assignment mapping with automatic Round 2+ generation.

This module:
- Maintains assignment data in a SQLite database
- Loads existing assignments from a JSON file (the exported output of the AssignmentSystem)
- Automatically generates new assignments when the user count exceeds current capacity
- Retrieves assignments for registered users

Usage
-----
    manager = UserAssignmentManager(
        db_path="wcag_validator_ui.db",
        config_json_path="sites_config.json",
        assignments_json_path="alt_text_assignments_output_target_overlap.json",
        assignments_xlsx_path="alt_text_assignments_output_target_overlap.xlsx",
    )

    # Get the stored assignments for a user (returns None if the user has none)
    assignments = manager.get_user_assignments("user123")

    # Register active users (auto-generates additional assignments if needed)
    manager.register_active_users(["user1", "user2", "user3"])
"""
import json
import sqlite3
import sys
from collections import defaultdict
from contextlib import closing
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# Import the assignment system
from .alt_text_assignment_target_overlap_multiple_round import (
    AssignmentSystem,
    SiteConfig,
)


class UserAssignmentManager:
    """
    Manages user-to-assignment mapping with SQLite persistence.

    Automatically handles Round 2+ mode when user count exceeds current capacity.

    The database holds three tables: ``user_assignments`` (one row per
    user/site pair, image indices stored as a JSON array), ``user_info``
    (user_id -> user_name mapping) and ``assignment_generation_log`` (one row
    per generation run).
    """

    def __init__(
        self,
        db_path: str,
        config_json_path: str,
        assignments_json_path: str = "alt_text_assignments_output_target_overlap.json",
        assignments_xlsx_path: str = "alt_text_assignments_output_target_overlap.xlsx",
        target_overlap: int = 2,
        seed: int = 42,
    ):
        """
        Initialize the User Assignment Manager.

        Construction has side effects: the site configuration is read, the
        database tables are created if missing, and any assignments found in
        ``assignments_json_path`` are upserted into the database.

        Parameters
        ----------
        db_path : str
            Path to SQLite database file
        config_json_path : str
            Path to sites configuration JSON file (from --config_json)
        assignments_json_path : str
            Path to output assignments JSON file
        assignments_xlsx_path : str
            Path to output assignments XLSX file
        target_overlap : int
            Minimum overlap between user image assignments
        seed : int
            Random seed for reproducibility
        """
        self.db_path = Path(db_path)
        self.config_json_path = Path(config_json_path)
        self.assignments_json_path = Path(assignments_json_path)
        self.assignments_xlsx_path = Path(assignments_xlsx_path)
        self.target_overlap = target_overlap
        self.seed = seed

        # Load configuration
        self.sites_config = self._load_sites_config()
        # Initialize database
        self._init_database()
        # Load existing assignments from JSON if available
        self._load_existing_assignments()

    def _load_sites_config(self) -> List["SiteConfig"]:
        """
        Load site configuration from JSON.

        Each entry must provide ``url`` and ``allowed_images``;
        ``images_per_user`` is optional and defaults to 6.
        """
        with open(self.config_json_path, "r", encoding="utf-8") as f:
            site_defs = json.load(f)
        return [
            SiteConfig(
                url=sd["url"],
                allowed_images=sd["allowed_images"],
                images_per_user=sd.get("images_per_user", 6),
            )
            for sd in site_defs
        ]

    def _init_database(self):
        """Create the tables and index used by the manager if they do not exist."""
        # closing() guarantees the connection is released even if a statement
        # fails; sqlite3's own context manager only manages the transaction.
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            # Create user_assignments table
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS user_assignments (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id TEXT NOT NULL,
                    site_url TEXT NOT NULL,
                    image_indices TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    UNIQUE(user_id, site_url)
                )
                """
            )
            # Create index for fast user lookups
            cursor.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_user_assignments_user_id
                ON user_assignments(user_id)
                """
            )
            # Create assignment_generation_log table to track Round 2+ runs
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS assignment_generation_log (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    generation_round INTEGER NOT NULL,
                    users_before INTEGER NOT NULL,
                    users_after INTEGER NOT NULL,
                    new_users_added INTEGER NOT NULL,
                    generated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    json_file TEXT,
                    xlsx_file TEXT
                )
                """
            )
            # Create table to map user_id and user_name
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS user_info (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id TEXT NOT NULL UNIQUE,
                    user_name TEXT
                )
                """
            )
            conn.commit()

    @staticmethod
    def _resolve_user_name(
        user_id: str, active_user_names: Optional[List[str]]
    ) -> Optional[str]:
        """
        Map a generated id of the form ``user<N>`` to the N-th active name.

        Returns None when no names were supplied, when the id does not follow
        the ``user<N>`` pattern, or when N falls outside the list.
        """
        if not active_user_names or not user_id.startswith("user"):
            return None
        try:
            index = int(user_id[4:]) - 1
        except ValueError:
            return None
        # Reject negative indices explicitly so "user0" does not silently
        # wrap around to the last name in the list.
        if 0 <= index < len(active_user_names):
            return active_user_names[index]
        return None

    def _load_existing_assignments(self, active_user_names: Optional[List[str]] = None):
        """
        Upsert the assignments stored in the JSON file into the database.

        Does nothing if the JSON file does not exist. Existing rows for the
        same (user_id, site_url) pair are overwritten with the file content.

        Parameters
        ----------
        active_user_names : list of str, optional
            Names used to fill ``user_info``; the id ``user<N>`` receives the
            N-th entry (1-based). When omitted, or when no name can be
            resolved for an id, a previously stored name is left untouched.
        """
        if not self.assignments_json_path.exists():
            return
        with open(self.assignments_json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        assignments = data.get("assignments", {})

        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            for user_id, sites_dict in assignments.items():
                user_name = self._resolve_user_name(user_id, active_user_names)
                for site_url, image_indices in sites_dict.items():
                    try:
                        cursor.execute(
                            """
                            INSERT INTO user_assignments (user_id, site_url, image_indices)
                            VALUES (?, ?, ?)
                            ON CONFLICT(user_id, site_url) DO UPDATE SET
                                image_indices = excluded.image_indices,
                                updated_at = CURRENT_TIMESTAMP
                            """,
                            (user_id, site_url, json.dumps(image_indices)),
                        )
                        # COALESCE keeps the stored name when this load has no
                        # name to offer (e.g. the reload done by __init__);
                        # otherwise every reload would reset names to NULL.
                        cursor.execute(
                            """
                            INSERT INTO user_info (user_id, user_name)
                            VALUES (?, ?)
                            ON CONFLICT(user_id) DO UPDATE SET
                                user_name = COALESCE(excluded.user_name, user_info.user_name)
                            """,
                            (user_id, user_name),
                        )
                    except sqlite3.IntegrityError:
                        print(
                            f"[DB] Error. Skipping existing assignment for user {user_id}, site {site_url}"
                        )
            conn.commit()

    def get_user_assignments(
        self, user_id: str, from_user_name: bool = False
    ) -> Optional[Dict[str, List[int]]]:
        """
        Get assignments for a user from database.

        Parameters
        ----------
        user_id : str
            User ID
        from_user_name : bool
            If True, treat user_id as user_name and look up corresponding
            user_id in user_info table before fetching assignments

        Returns
        -------
        dict or None
            {site_url: [image_indices]} or None if user not found
        """
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            if from_user_name:
                print(f"[DB] Looking up user_id for user_name: {user_id}")
                cursor.execute(
                    """
                    SELECT user_id
                    FROM user_info
                    WHERE user_name = ?
                    """,
                    (user_id,),
                )
                result = cursor.fetchone()
                print(f"[DB] Lookup result for user_name '{user_id}': {result}")
                if not result:
                    return None
                user_id = result[0]
            cursor.execute(
                """
                SELECT site_url, image_indices
                FROM user_assignments
                WHERE user_id = ?
                """,
                (user_id,),
            )
            rows = cursor.fetchall()

        if not rows:
            return None
        # image_indices is stored as a JSON array string
        return {
            site_url: json.loads(image_indices_json)
            for site_url, image_indices_json in rows
        }

    def get_all_user_ids(self) -> List[str]:
        """Get all user IDs that have at least one assignment, sorted as text."""
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT DISTINCT user_id
                FROM user_assignments
                ORDER BY user_id
                """
            )
            return [row[0] for row in cursor.fetchall()]

    def register_active_users(
        self, active_user_ids: List[str]
    ) -> Dict[str, Dict[str, List[int]]]:
        """
        Register active users and ensure assignments exist for all.

        If user count exceeds current capacity:
        1. Calls get_managed_user_count() to check capacity
        2. Runs AssignmentSystem in Round 2+ mode if needed
        3. Updates JSON/XLSX files
        4. Logs the generation event

        Parameters
        ----------
        active_user_ids : list of str
            List of currently active user IDs. The same list is also used,
            by position, as the names stored in ``user_info``.

        Returns
        -------
        dict
            {user_id: {site_url: [image_indices]}} for every active user that
            has assignments stored under that exact id; others are omitted.
        """
        managed_count = self.get_managed_user_count()
        new_user_count = len(active_user_ids)
        active_user_names = active_user_ids
        print(f"\n[UserAssignmentManager] active_user_name: {active_user_names}")

        # Check if we need to generate new assignments
        if new_user_count > managed_count:
            num_new_users = new_user_count - managed_count
            print("\n[UserAssignmentManager] Expanding assignments:")
            print(f" Current capacity: {managed_count} users")
            print(f" Required capacity: {new_user_count} users")
            print(f" Generating {num_new_users} additional assignments...\n")
            self._generate_round2_assignments(num_new_users, active_user_names)

        # Retrieve assignments for all active users.
        # NOTE(review): the lookup uses the active id verbatim, while generated
        # ids are "user<N>"; entries that are real names are only reachable via
        # get_user_assignments(name, from_user_name=True) - confirm intended.
        result = {}
        for user_id in active_user_ids:
            assignments = self.get_user_assignments(user_id)
            if assignments is None:
                print(f"[WARNING] No assignments found for user {user_id}. It is fine")
            else:
                result[user_id] = assignments
        return result

    def get_managed_user_count(self) -> int:
        """
        Get the number of users currently managed by assignments.

        Returns
        -------
        int
            Number of unique users with assignments
        """
        return len(self.get_all_user_ids())

    def _generate_round2_assignments(
        self, num_new_users: int, active_user_names: List[str]
    ):
        """
        Generate Round 2+ assignments using AssignmentSystem.

        New users receive the ids ``user<current+1>`` .. ``user<current+n>``.
        The updated assignments are written to the JSON and XLSX files,
        reloaded into the database, and the run is logged.

        Parameters
        ----------
        num_new_users : int
            Number of new users to add; nothing happens when it is <= 0
        active_user_names : list of str
            Names of all active users, passed on to fill ``user_info``
        """
        if num_new_users <= 0:
            # Nothing to generate; also avoids indexing empty lists below.
            return

        current_users = self.get_managed_user_count()

        # Create AssignmentSystem with current site configuration
        system = AssignmentSystem(
            sites=self.sites_config,
            target_overlap=self.target_overlap,
            seed=self.seed,
        )
        # Load existing assignments
        system = self._load_previous_assignments_into_system(system)

        # Generate new users
        new_user_ids = [f"user{current_users + i}" for i in range(1, num_new_users + 1)]
        new_user_names = active_user_names[
            current_users : current_users + num_new_users
        ]
        print(
            f"[AssignmentSystem] Adding users: {new_user_ids[0]} to {new_user_ids[-1]}"
        )
        if new_user_names:
            print(
                f"[AssignmentSystem] Corresponding names: {new_user_names[0]} to {new_user_names[-1]}"
            )
        for uid in new_user_ids:
            system.add_user(uid)

        # Save updated assignments
        print(f"[AssignmentSystem] Saving to {self.assignments_json_path}")
        system.to_json(str(self.assignments_json_path))
        print(f"[AssignmentSystem] Saving to {self.assignments_xlsx_path}")
        system.to_xlsx(str(self.assignments_xlsx_path))

        # Load new assignments into database; the entire name list is passed
        # so user_info can be updated for every user<N> id.
        self._load_existing_assignments(active_user_names=active_user_names)

        # Log the generation event
        self._log_generation_event(
            generation_round=self._get_generation_round() + 1,
            users_before=current_users,
            users_after=current_users + num_new_users,
            new_users_added=num_new_users,
            json_file=str(self.assignments_json_path),
            xlsx_file=str(self.assignments_xlsx_path),
        )
        print("[UserAssignmentManager] Assignments updated successfully")

    def _load_previous_assignments_into_system(
        self, system: "AssignmentSystem"
    ) -> "AssignmentSystem":
        """
        Load previously exported assignments into an AssignmentSystem object.

        Parameters
        ----------
        system : AssignmentSystem
            Already-constructed system with correct SiteConfig pool

        Returns
        -------
        AssignmentSystem
            System with existing assignments populated, or the system
            unchanged when the XLSX file does not exist
        """
        if not self.assignments_xlsx_path.exists():
            print(
                f"[AssignmentSystem] No previous assignments found at {self.assignments_xlsx_path}"
            )
            return system
        print(
            f"[AssignmentSystem] Loading previous assignments from {self.assignments_xlsx_path}"
        )
        # Use the _load_previous_assignments function from the module
        from .alt_text_assignment_target_overlap_multiple_round import (
            _load_previous_assignments,
        )

        return _load_previous_assignments(system, str(self.assignments_xlsx_path))

    def _get_generation_round(self) -> int:
        """Get the highest generation round recorded in the log (0 if empty)."""
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT MAX(generation_round)
                FROM assignment_generation_log
                """
            )
            result = cursor.fetchone()[0]
        return result if result is not None else 0

    def _log_generation_event(
        self,
        generation_round: int,
        users_before: int,
        users_after: int,
        new_users_added: int,
        json_file: str,
        xlsx_file: str,
    ):
        """Insert one row describing a generation run into the log table."""
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO assignment_generation_log
                (generation_round, users_before, users_after, new_users_added, json_file, xlsx_file)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    generation_round,
                    users_before,
                    users_after,
                    new_users_added,
                    json_file,
                    xlsx_file,
                ),
            )
            conn.commit()
        # The two counts need a separator; printed back to back they read as
        # a single number (5 and 7 became "57 users").
        print(
            f"[DB] Logged generation event: Round {generation_round}, "
            f"{users_before} -> {users_after} users"
        )

    def get_generation_history(self) -> List[Dict]:
        """
        Get the complete assignment generation history.

        Returns
        -------
        list of dict
            One dict per log row, keyed by column name, ordered by
            generation_round ascending
        """
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT generation_round, users_before, users_after, new_users_added,
                       generated_at, json_file, xlsx_file
                FROM assignment_generation_log
                ORDER BY generation_round ASC
                """
            )
            columns = [desc[0] for desc in cursor.description]
            return [dict(zip(columns, row)) for row in cursor.fetchall()]

    def get_statistics(self) -> Dict:
        """
        Get statistics about user assignments.

        Returns
        -------
        dict
            ``total_users`` (int), ``users_per_site`` ({site_url: count}) and
            ``avg_images_per_site`` ({site_url: average list length})
        """
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()

            # Total unique users
            cursor.execute("SELECT COUNT(DISTINCT user_id) FROM user_assignments")
            total_users = cursor.fetchone()[0]

            # Users per site
            cursor.execute(
                """
                SELECT site_url, COUNT(DISTINCT user_id) as user_count
                FROM user_assignments
                GROUP BY site_url
                ORDER BY site_url
                """
            )
            users_per_site = {row[0]: row[1] for row in cursor.fetchall()}

            # Average images per user per site
            try:
                cursor.execute(
                    """
                    SELECT site_url,
                           AVG(json_array_length(image_indices)) as avg_images
                    FROM user_assignments
                    GROUP BY site_url
                    ORDER BY site_url
                    """
                )
                avg_images_per_site = {row[0]: row[1] for row in cursor.fetchall()}
            except sqlite3.OperationalError:
                # Fallback for SQLite builds without the JSON functions: the
                # query itself fails there, so compute the average in Python.
                cursor.execute(
                    """
                    SELECT site_url, image_indices
                    FROM user_assignments
                    ORDER BY site_url
                    """
                )
                lengths = defaultdict(list)
                for site_url, image_indices_json in cursor.fetchall():
                    lengths[site_url].append(len(json.loads(image_indices_json)))
                avg_images_per_site = {
                    site_url: sum(counts) / len(counts)
                    for site_url, counts in lengths.items()
                }

        return {
            "total_users": total_users,
            "users_per_site": users_per_site,
            "avg_images_per_site": avg_images_per_site,
        }
def _run_demo() -> None:
    """
    Demo/test usage of UserAssignmentManager.

    Builds a manager from files next to this module (the database lives in
    the sibling ``persistence`` directory), registers seven users named
    ``user1`` .. ``user7`` and prints assignments, statistics and the most
    recent generation history. Kept in a function so the demo variables do
    not leak into module globals.
    """
    # Set paths (adjust as needed for your environment)
    base_dir = Path(__file__).parent
    manager = UserAssignmentManager(
        db_path=str(base_dir.parent / "persistence" / "wcag_validator_ui.db"),
        config_json_path=str(base_dir / "sites_config.json"),
        assignments_json_path=str(
            base_dir / "alt_text_assignments_output_target_overlap.json"
        ),
        assignments_xlsx_path=str(
            base_dir / "alt_text_assignments_output_target_overlap.xlsx"
        ),
    )
    print("\n=== User Assignment Manager Demo ===\n")

    # Get current managed users
    managed_users = manager.get_all_user_ids()
    print(f"Currently managed users: {managed_users}")
    print(f"Total managed users: {manager.get_managed_user_count()}\n")

    # Define active users (including new ones)
    active_users = [f"user{i}" for i in range(1, 8)]
    print(f"Active users (including new): {active_users}\n")

    # Register and get assignments
    print("Registering active users...")
    assignments = manager.register_active_users(active_users)
    print(f"\nAssignments for {len(assignments)} users:")
    for user_id in sorted(assignments)[:3]:  # Show first 3
        print(f" {user_id}: {len(assignments[user_id])} sites")

    # Get statistics
    stats = manager.get_statistics()
    print("\n=== Statistics ===")
    print(f"Total users: {stats['total_users']}")
    print(f"Users per site: {stats['users_per_site']}")

    # Get history
    history = manager.get_generation_history()
    if history:
        print("\n=== Generation History ===")
        for event in history[-3:]:  # Show the three most recent runs
            print(
                f" Round {event['generation_round']}: {event['users_after']} users "
                f"({event['new_users_added']} new)"
            )


if __name__ == "__main__":
    _run_demo()