"""
User Assignment Manager
=======================
Manages user-to-assignment mapping with automatic Round 2+ generation.

This module:
- Maintains assignment data in a SQLite database
- Loads existing assignments from a JSON file (the exported output of the AssignmentSystem)
- Automatically generates new assignments when the user count exceeds the current capacity
- Retrieves assignments for registered users

Usage
-----
manager = UserAssignmentManager(
    db_path="wcag_validator_ui.db",
    config_json_path="sites_config.json",
    assignments_json_path="alt_text_assignments_output_target_overlap.json",
    assignments_xlsx_path="alt_text_assignments_output_target_overlap.xlsx"
)

# Get assignments for a user (returns None if the user has no stored assignments)
assignments = manager.get_user_assignments("user123")

# Register new active users (generates additional assignments if needed)
manager.register_active_users(["user1", "user2", "user3"])
"""

import json
import sqlite3
import sys
from collections import defaultdict
from contextlib import closing
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# Import the assignment system
from .alt_text_assignment_target_overlap_multiple_round import (
    AssignmentSystem,
    SiteConfig,
)


class UserAssignmentManager:
    """
    Manages user-to-assignment mapping with SQLite persistence.

    Automatically handles Round 2+ mode when user count exceeds current capacity.

    Three tables are maintained in the database at ``db_path``:

    - ``user_assignments``: one row per (user_id, site_url) holding the
      JSON-encoded list of image indices.
    - ``user_info``: maps a generated ``user_id`` to an optional ``user_name``.
    - ``assignment_generation_log``: one row per Round 2+ generation run.
    """

    def __init__(
        self,
        db_path: str,
        config_json_path: str,
        assignments_json_path: str = "alt_text_assignments_output_target_overlap.json",
        assignments_xlsx_path: str = "alt_text_assignments_output_target_overlap.xlsx",
        target_overlap: int = 2,
        seed: int = 42,
    ):
        """
        Initialize the User Assignment Manager.

        Loads the site configuration, creates the database tables if they do
        not exist, and imports any assignments already exported to JSON.

        Parameters
        ----------
        db_path : str
            Path to SQLite database file
        config_json_path : str
            Path to sites configuration JSON file (from --config_json)
        assignments_json_path : str
            Path to output assignments JSON file
        assignments_xlsx_path : str
            Path to output assignments XLSX file
        target_overlap : int
            Minimum overlap between user image assignments
        seed : int
            Random seed for reproducibility
        """
        self.db_path = Path(db_path)
        self.config_json_path = Path(config_json_path)
        self.assignments_json_path = Path(assignments_json_path)
        self.assignments_xlsx_path = Path(assignments_xlsx_path)
        self.target_overlap = target_overlap
        self.seed = seed

        # Load configuration
        self.sites_config = self._load_sites_config()

        # Initialize database
        self._init_database()

        # Load existing assignments from JSON if available
        self._load_existing_assignments()

    def _connect(self):
        """Return a context manager yielding a connection that is always closed.

        ``sqlite3.Connection`` used directly as a context manager only
        commits/rolls back; it does not close. Wrapping it in ``closing``
        guarantees the handle is released even when a statement raises.
        """
        return closing(sqlite3.connect(self.db_path))

    @staticmethod
    def _resolve_user_name(
        user_id: str, active_user_names: Optional[List[str]]
    ) -> Optional[str]:
        """Map a generated ``user<N>`` id to the N-th (1-based) active user name.

        Returns None when no names were supplied, when the id does not follow
        the ``user<N>`` pattern, or when N falls outside the list. A negative
        index is never used, so ``user0`` does not wrap to the last name.
        """
        if not active_user_names or not user_id.startswith("user"):
            return None
        try:
            position = int(user_id[4:]) - 1
        except ValueError:
            # e.g. "username": starts with "user" but has no numeric suffix
            return None
        if 0 <= position < len(active_user_names):
            return active_user_names[position]
        return None

    def _load_sites_config(self) -> "List[SiteConfig]":
        """Load site configuration from JSON.

        Each entry must provide ``url`` and ``allowed_images``;
        ``images_per_user`` defaults to 6 when absent.
        """
        with open(self.config_json_path, "r", encoding="utf-8") as f:
            site_defs = json.load(f)

        return [
            SiteConfig(
                url=sd["url"],
                allowed_images=sd["allowed_images"],
                images_per_user=sd.get("images_per_user", 6),
            )
            for sd in site_defs
        ]

    def _init_database(self):
        """Initialize SQLite database tables and index (idempotent)."""
        with self._connect() as conn:
            cursor = conn.cursor()

            # Create user_assignments table
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS user_assignments (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id TEXT NOT NULL,
                    site_url TEXT NOT NULL,
                    image_indices TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    UNIQUE(user_id, site_url)
                )
                """
            )

            # Create index for fast user lookups
            cursor.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_user_assignments_user_id
                ON user_assignments(user_id)
                """
            )

            # Create assignment_generation_log table to track Round 2+ runs
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS assignment_generation_log (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    generation_round INTEGER NOT NULL,
                    users_before INTEGER NOT NULL,
                    users_after INTEGER NOT NULL,
                    new_users_added INTEGER NOT NULL,
                    generated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    json_file TEXT,
                    xlsx_file TEXT
                )
                """
            )

            # Create table to map user_id and user_name
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS user_info (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id TEXT NOT NULL UNIQUE,
                    user_name TEXT
                )
                """
            )

            conn.commit()

    def _load_existing_assignments(self, active_user_names: Optional[List[str]] = None):
        """Load assignments from the JSON file into the database (upsert).

        Parameters
        ----------
        active_user_names : list of str, optional
            When given, ``user<N>`` ids are associated with the N-th (1-based)
            name in ``user_info``. When omitted (as during construction),
            names already stored in ``user_info`` are left untouched.
        """
        if not self.assignments_json_path.exists():
            return

        with open(self.assignments_json_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        assignments = data.get("assignments", {})

        with self._connect() as conn:
            cursor = conn.cursor()

            for user_id, sites_dict in assignments.items():
                user_name = self._resolve_user_name(user_id, active_user_names)

                for site_url, image_indices in sites_dict.items():
                    try:
                        cursor.execute(
                            """
                            INSERT INTO user_assignments (user_id, site_url, image_indices)
                            VALUES (?, ?, ?)
                            ON CONFLICT(user_id, site_url) DO UPDATE SET
                                image_indices = excluded.image_indices,
                                updated_at = CURRENT_TIMESTAMP
                            """,
                            (user_id, site_url, json.dumps(image_indices)),
                        )

                        # COALESCE keeps a previously stored name when this
                        # load carries no name for the user; otherwise every
                        # name-less reload would reset user_name to NULL.
                        cursor.execute(
                            """
                            INSERT INTO user_info (user_id, user_name)
                            VALUES (?, ?)
                            ON CONFLICT(user_id) DO UPDATE SET
                                user_name = COALESCE(excluded.user_name, user_info.user_name)
                            """,
                            (user_id, user_name),
                        )
                    except sqlite3.IntegrityError:
                        print(
                            f"[DB] Error. Skipping existing assignment for user {user_id}, site {site_url}"
                        )

            conn.commit()

    def get_user_assignments(
        self, user_id: str, from_user_name: bool = False
    ) -> Optional[Dict[str, List[int]]]:
        """
        Get assignments for a user from database.

        Parameters
        ----------
        user_id : str
            User ID
        from_user_name : bool
            If True, treat user_id as user_name and look up corresponding
            user_id in user_info table before fetching assignments

        Returns
        -------
        dict or None
            {site_url: [image_indices]} or None if user not found
        """
        with self._connect() as conn:
            cursor = conn.cursor()

            if from_user_name:
                print(f"[DB] Looking up user_id for user_name: {user_id}")

                cursor.execute(
                    """
                    SELECT user_id
                    FROM user_info
                    WHERE user_name = ?
                    """,
                    (user_id,),
                )
                result = cursor.fetchone()
                print(f"[DB] Lookup result for user_name '{user_id}': {result}")
                if not result:
                    return None
                user_id = result[0]

            cursor.execute(
                """
                SELECT site_url, image_indices
                FROM user_assignments
                WHERE user_id = ?
                """,
                (user_id,),
            )
            rows = cursor.fetchall()

        if not rows:
            return None

        return {site_url: json.loads(payload) for site_url, payload in rows}

    def get_all_user_ids(self) -> List[str]:
        """Get all distinct user IDs that have assignments, sorted by the database."""
        with self._connect() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT DISTINCT user_id
                FROM user_assignments
                ORDER BY user_id
                """
            )
            return [row[0] for row in cursor.fetchall()]

    def register_active_users(
        self, active_user_ids: List[str]
    ) -> Dict[str, Dict[str, List[int]]]:
        """
        Register active users and ensure assignments exist for all.

        If user count exceeds current capacity:
        1. Calls get_managed_user_count() to check capacity
        2. Runs AssignmentSystem in Round 2+ mode if needed
        3. Updates JSON/XLSX files
        4. Logs the generation event

        Parameters
        ----------
        active_user_ids : list of str
            List of currently active user IDs. The same list is forwarded as
            the user-name list when new assignments are generated.

        Returns
        -------
        dict
            {user_id: {site_url: [image_indices]}} for every active user that
            has stored assignments; users without any are omitted.
        """
        managed_count = self.get_managed_user_count()
        new_user_count = len(active_user_ids)

        active_user_names = active_user_ids
        print(f"\n[UserAssignmentManager] active_user_name: {active_user_names}")

        # Check if we need to generate new assignments
        if new_user_count > managed_count:
            num_new_users = new_user_count - managed_count
            print("\n[UserAssignmentManager] Expanding assignments:")
            print(f" Current capacity: {managed_count} users")
            print(f" Required capacity: {new_user_count} users")
            print(f" Generating {num_new_users} additional assignments...\n")

            self._generate_round2_assignments(num_new_users, active_user_names)

        # Retrieve assignments for all active users
        # NOTE(review): the lookup is by user_id, so entries that are names
        # rather than generated "user<N>" ids will not match - confirm intent.
        result = {}
        for user_id in active_user_ids:
            assignments = self.get_user_assignments(user_id)
            if assignments is None:
                print(f"[WARNING] No assignments found for user {user_id}. It is fine")
            else:
                result[user_id] = assignments

        return result

    def get_managed_user_count(self) -> int:
        """
        Get the number of users currently managed by assignments.

        Returns
        -------
        int
            Number of unique users with assignments
        """
        return len(self.get_all_user_ids())

    def _generate_round2_assignments(
        self, num_new_users: int, active_user_names: List[str]
    ):
        """
        Generate Round 2+ assignments using AssignmentSystem.

        New ids are numbered ``user<current+1>`` .. ``user<current+n>``.
        NOTE(review): this presumes existing ids already follow the
        ``user1..user<current>`` numbering - confirm against the exported data.

        Parameters
        ----------
        num_new_users : int
            Number of new users to add; nothing happens when it is <= 0
        active_user_names : list of str
            Names of all active users, positionally aligned with user ids
        """
        if num_new_users <= 0:
            # Nothing to add; also avoids indexing an empty id list below.
            return

        current_users = self.get_managed_user_count()

        # Create AssignmentSystem with current site configuration
        system = AssignmentSystem(
            sites=self.sites_config,
            target_overlap=self.target_overlap,
            seed=self.seed,
        )

        # Load existing assignments
        system = self._load_previous_assignments_into_system(system)

        # Generate new users
        new_user_ids = [f"user{current_users + i}" for i in range(1, num_new_users + 1)]
        new_user_names = active_user_names[
            current_users : current_users + num_new_users
        ]

        print(
            f"[AssignmentSystem] Adding users: {new_user_ids[0]} to {new_user_ids[-1]}"
        )
        if new_user_names:
            # The slice can be empty if fewer names than users were supplied.
            print(
                f"[AssignmentSystem] Corresponding names: {new_user_names[0]} to {new_user_names[-1]}"
            )

        for uid in new_user_ids:
            system.add_user(uid)

        # Save updated assignments
        print(f"[AssignmentSystem] Saving to {self.assignments_json_path}")
        system.to_json(str(self.assignments_json_path))

        print(f"[AssignmentSystem] Saving to {self.assignments_xlsx_path}")
        system.to_xlsx(str(self.assignments_xlsx_path))

        # Load new assignments into database; the entire name list is passed
        # so user_info can be updated with names.
        self._load_existing_assignments(active_user_names=active_user_names)

        # Log the generation event
        self._log_generation_event(
            generation_round=self._get_generation_round() + 1,
            users_before=current_users,
            users_after=current_users + num_new_users,
            new_users_added=num_new_users,
            json_file=str(self.assignments_json_path),
            xlsx_file=str(self.assignments_xlsx_path),
        )

        print("[UserAssignmentManager] Assignments updated successfully")

    def _load_previous_assignments_into_system(
        self, system: "AssignmentSystem"
    ) -> "AssignmentSystem":
        """
        Load previously exported assignments into an AssignmentSystem object.

        Parameters
        ----------
        system : AssignmentSystem
            Already-constructed system with correct SiteConfig pool

        Returns
        -------
        AssignmentSystem
            The system returned by ``_load_previous_assignments``, or the
            input system unchanged when no XLSX export exists
        """
        if not self.assignments_xlsx_path.exists():
            print(
                f"[AssignmentSystem] No previous assignments found at {self.assignments_xlsx_path}"
            )
            return system

        print(
            f"[AssignmentSystem] Loading previous assignments from {self.assignments_xlsx_path}"
        )

        # Use the _load_previous_assignments function from the module
        from .alt_text_assignment_target_overlap_multiple_round import (
            _load_previous_assignments,
        )

        return _load_previous_assignments(system, str(self.assignments_xlsx_path))

    def _get_generation_round(self) -> int:
        """Get the highest generation round in the log, or 0 if the log is empty."""
        with self._connect() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT MAX(generation_round)
                FROM assignment_generation_log
                """
            )
            result = cursor.fetchone()[0]

        return result if result is not None else 0

    def _log_generation_event(
        self,
        generation_round: int,
        users_before: int,
        users_after: int,
        new_users_added: int,
        json_file: str,
        xlsx_file: str,
    ):
        """Log a generation event to the database."""
        with self._connect() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO assignment_generation_log
                (generation_round, users_before, users_after, new_users_added, json_file, xlsx_file)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    generation_round,
                    users_before,
                    users_after,
                    new_users_added,
                    json_file,
                    xlsx_file,
                ),
            )
            conn.commit()

        print(
            f"[DB] Logged generation event: Round {generation_round}, "
            f"{users_before}→{users_after} users"
        )

    def get_generation_history(self) -> List[Dict]:
        """Get the complete assignment generation history, oldest round first."""
        with self._connect() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT generation_round, users_before, users_after, new_users_added,
                       generated_at, json_file, xlsx_file
                FROM assignment_generation_log
                ORDER BY generation_round ASC
                """
            )
            columns = [desc[0] for desc in cursor.description]
            return [dict(zip(columns, row)) for row in cursor.fetchall()]

    def get_statistics(self) -> Dict:
        """Get statistics about user assignments.

        Returns
        -------
        dict
            ``total_users`` (int), ``users_per_site`` ({site_url: count}) and
            ``avg_images_per_site`` ({site_url: average list length}).
        """
        with self._connect() as conn:
            cursor = conn.cursor()

            # Total unique users
            cursor.execute("SELECT COUNT(DISTINCT user_id) FROM user_assignments")
            total_users = cursor.fetchone()[0]

            # Users per site
            cursor.execute(
                """
                SELECT site_url, COUNT(DISTINCT user_id) as user_count
                FROM user_assignments
                GROUP BY site_url
                ORDER BY site_url
                """
            )
            users_per_site = {row[0]: row[1] for row in cursor.fetchall()}

            # Average images per user per site
            try:
                cursor.execute(
                    """
                    SELECT site_url,
                           AVG(json_array_length(image_indices)) as avg_images
                    FROM user_assignments
                    GROUP BY site_url
                    ORDER BY site_url
                    """
                )
                avg_images_per_site = {row[0]: row[1] for row in cursor.fetchall()}
            except sqlite3.OperationalError:
                # Fallback for SQLite builds without the JSON functions:
                # decode the stored lists and average their lengths in Python.
                cursor.execute(
                    """
                    SELECT site_url, image_indices
                    FROM user_assignments
                    ORDER BY site_url
                    """
                )
                lengths = defaultdict(list)
                for site_url, payload in cursor.fetchall():
                    decoded = json.loads(payload)
                    # json_array_length yields 0 for non-array values
                    lengths[site_url].append(
                        len(decoded) if isinstance(decoded, list) else 0
                    )
                avg_images_per_site = {
                    site_url: sum(values) / len(values)
                    for site_url, values in lengths.items()
                }

        return {
            "total_users": total_users,
            "users_per_site": users_per_site,
            "avg_images_per_site": avg_images_per_site,
        }


if __name__ == "__main__":
    # Demo/test usage: build a manager next to this file, register seven
    # users, then print assignments, statistics and the generation history.

    # Set paths (adjust as needed for your environment)
    base_dir = Path(__file__).parent
    database_file = base_dir.parent / "persistence" / "wcag_validator_ui.db"
    export_stem = "alt_text_assignments_output_target_overlap"

    manager = UserAssignmentManager(
        db_path=str(database_file),
        config_json_path=str(base_dir / "sites_config.json"),
        assignments_json_path=str(base_dir / f"{export_stem}.json"),
        assignments_xlsx_path=str(base_dir / f"{export_stem}.xlsx"),
    )

    print("\n=== User Assignment Manager Demo ===\n")

    # Get current managed users
    known_ids = manager.get_all_user_ids()
    print(f"Currently managed users: {known_ids}")
    print(f"Total managed users: {manager.get_managed_user_count()}\n")

    # Define active users (including new ones)
    active_users = ["user" + str(number) for number in range(1, 8)]
    print(f"Active users (including new): {active_users}\n")

    # Register and get assignments
    print("Registering active users...")
    assignments = manager.register_active_users(active_users)

    print(f"\nAssignments for {len(assignments)} users:")
    # Show first 3
    for uid in sorted(assignments)[:3]:
        site_count = len(assignments[uid])
        print(f" {uid}: {site_count} sites")

    # Get statistics
    stats = manager.get_statistics()
    print("\n=== Statistics ===")
    print(f"Total users: {stats['total_users']}")
    print(f"Users per site: {stats['users_per_site']}")

    # Get history
    history = manager.get_generation_history()
    if history:
        print("\n=== Generation History ===")
        # Only the three most recent generation events are shown
        for entry in history[-3:]:
            print(
                f" Round {entry['generation_round']}: {entry['users_after']} users "
                f"({entry['new_users_added']} new)"
            )