# wcag_AI_validation/wcag_validator_RESTserver.py
#
# (source-viewer metadata: 144 lines, 3.8 KiB, Python)

# To launch the server:
#   python wcag_validator_RESTserver.py [-p PORT] [-H HOST]
import sys
import argparse
import logging
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import warnings
warnings.filterwarnings("ignore")
from dotenv import load_dotenv, find_dotenv
from restserver.routers import (
routes_health,
routes_local_db,
routes_wcag_alttext,
routes_extract_images,
)
from dependences.utils import (
db_persistence_startup,
return_from_env_valid,
disclaim_bool_string,
)
def server(connection_db, mllm_settings):
    """Build and return the FastAPI application for the WCAG validator.

    Args:
        connection_db: Persistence handle (as produced by
            ``db_persistence_startup``) passed to the local-DB and
            alt-text evaluation route groups.
        mllm_settings: MLLM configuration dict forwarded to the alt-text
            evaluation route group.

    Returns:
        The ``FastAPI`` instance with CORS middleware installed and the
        health, local-DB, alt-text and image-extraction routers mounted at
        the root prefix.
    """
    application = FastAPI(title="HIISlab WCAG Validator REST Server")

    # NOTE(review): allow_origins=["*"] together with allow_credentials=True.
    # The CORS spec forbids a literal "*" on credentialed responses, and
    # Starlette answers by echoing the request Origin instead, which lets any
    # site make credentialed calls -- confirm this policy is intended.
    application.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Instantiate every route group first (left-to-right, same order as the
    # mounting below), then attach each group's router to the app.
    route_groups = (
        routes_health.HealthRoutes(),
        routes_local_db.LocalDBRoutes(connection_db),
        routes_wcag_alttext.WCAGAltTextValuationRoutes(
            connection_db, mllm_settings
        ),
        routes_extract_images.ExtractImagesRoutes(),
    )
    for group in route_groups:
        application.include_router(group.router, prefix="")

    return application
def app_startup():
    """Initialise persistence and collect the MLLM configuration.

    Calls ``db_persistence_startup`` for the ``wcag_validator_results``
    table. It then reads the ``USE_OPENAI_MODEL`` setting, which selects
    either the ``*_OPENAI`` or the ``*_LOCAL`` family of variables for the
    MLLM endpoint, API key and model id.

    Returns:
        tuple: ``(connection_db, mllm_model_settings)``, where the settings
        dict has the keys ``openai_model``, ``mllm_end_point``,
        ``mllm_api_key`` and ``mllm_model_id``.
    """
    connection_db = db_persistence_startup(table="wcag_validator_results")

    use_openai_flag = disclaim_bool_string(
        return_from_env_valid("USE_OPENAI_MODEL", "False")
    )
    # Selection rule is equality with True, not plain truthiness.
    openai_model = bool(use_openai_flag == True)  # noqa: E712

    # Names of the (end point, API key, model id) variables for each backend.
    env_var_names = {
        True: (
            "MLLM_END_POINT_OPENAI",
            "MLLM_API_KEY_OPENAI",
            "MLLM_MODEL_ID_OPENAI",
        ),
        False: (
            "MLLM_END_POINT_LOCAL",
            "MLLM_API_KEY_LOCAL",
            "MLLM_MODEL_ID_LOCAL",
        ),
    }
    end_point_var, api_key_var, model_id_var = env_var_names[openai_model]
    mllm_end_point = return_from_env_valid(end_point_var, "")
    mllm_api_key = return_from_env_valid(api_key_var, "")
    mllm_model_id = return_from_env_valid(model_id_var, "")

    # Only the endpoint and model id are printed; the API key is not.
    print("mllm_end_point:", mllm_end_point)
    print("mllm_model_id:", mllm_model_id)

    return connection_db, {
        "openai_model": openai_model,
        "mllm_end_point": mllm_end_point,
        "mllm_api_key": mllm_api_key,
        "mllm_model_id": mllm_model_id,
    }
def run_server(
    connection_db,
    mllm_settings,
    host,
    port,
):
    """Build the application via ``server`` and serve it with uvicorn.

    Args:
        connection_db: Persistence handle forwarded to ``server``.
        mllm_settings: MLLM configuration dict forwarded to ``server``.
        host: Interface address passed to ``uvicorn.run``.
        port: TCP port passed to ``uvicorn.run``.
    """
    uvicorn.run(server(connection_db, mllm_settings), host=host, port=port)
def cli(sys_argv):
    """Parse command-line arguments and launch the REST server.

    Loads variables from the nearest ``.env`` file (if one is found) and
    derives the default host/port from them. It then parses *sys_argv*. Only
    after the arguments are valid does it run application startup
    (``app_startup``) and start the server.

    Args:
        sys_argv: Command-line arguments without the program name
            (e.g. ``sys.argv[1:]``).
    """
    # Load environment variables first: the CLI defaults below are read from them.
    env_path = find_dotenv(filename=".env")
    if env_path == "":
        print(
            "rest server env path not found: service starting with the default params values"
        )
    _ = load_dotenv(env_path)  # read .env file

    default_rest_port = return_from_env_valid("rest_port", 8000)
    default_rest_host = return_from_env_valid("rest_host", "0.0.0.0")

    parser = argparse.ArgumentParser()
    # ----------------
    # Server parameters
    # ----------------
    # Help text names the exact (lower-case) keys read above.
    parser.add_argument(
        "-p",
        "--port",
        help="port for server (default: rest_port env setting or 8000)",
        default=default_rest_port,
        # argparse also applies this conversion to a string default (e.g. from .env).
        type=int,
    )
    parser.add_argument(
        "-H",
        "--host",
        help="host for server (default: rest_host env setting or 0.0.0.0)",
        default=default_rest_host,
    )
    args = parser.parse_args(sys_argv)
    # NOTE(review): nothing visible here configures the root logger, so this
    # INFO record is dropped unless logging is set up elsewhere -- confirm.
    logging.info("service started with cli args:%s", args)

    # Run application startup only after arguments parsed successfully, so
    # `--help` or an invalid option exits without calling db_persistence_startup.
    connection_db, mllm_settings = app_startup()
    run_server(
        connection_db,
        mllm_settings,
        args.host,
        args.port,
    )
if __name__ == "__main__":
    # Forward everything after the script name to the CLI parser.
    cli_arguments = sys.argv[1:]
    cli(cli_arguments)