#### To launch the script:
#
#   python wcag_validator_RESTserver.py
import sys
|
|
import argparse
|
|
import logging
|
|
import uvicorn
|
|
from fastapi import FastAPI
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
import warnings
|
|
|
|
warnings.filterwarnings("ignore")
|
|
from dotenv import load_dotenv, find_dotenv
|
|
|
|
|
|
from restserver.routers import (
|
|
routes_health,
|
|
routes_local_db,
|
|
routes_wcag_alttext,
|
|
routes_extract_images,
|
|
)
|
|
|
|
from dependences.utils import (
|
|
db_persistence_startup,
|
|
return_from_env_valid,
|
|
disclaim_bool_string,
|
|
)
|
|
|
|
|
|
def server(connection_db, mllm_settings):
    """Build and wire up the FastAPI application.

    Args:
        connection_db: database connection handed to the DB-backed routers.
        mllm_settings: dict of multimodal-LLM settings (endpoint, api key,
            model id, openai flag) consumed by the alt-text valuation routes.

    Returns:
        A fully configured FastAPI app with CORS middleware and all routers
        mounted at the root prefix.
    """
    application = FastAPI(title="HIISlab WCAG Validator REST Server")

    # NOTE(review): browsers treat allow_origins=["*"] together with
    # allow_credentials=True as "credentials disabled" — confirm intended.
    application.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Instantiate each route group, then mount every router at the root path.
    mounted_routers = (
        routes_health.HealthRoutes().router,
        routes_local_db.LocalDBRoutes(connection_db).router,
        routes_wcag_alttext.WCAGAltTextValuationRoutes(
            connection_db, mllm_settings
        ).router,
        routes_extract_images.ExtractImagesRoutes().router,
    )
    for mounted in mounted_routers:
        application.include_router(mounted, prefix="")

    return application
|
|
|
|
|
|
def app_startup():
    """Initialize persistence and resolve the multimodal-LLM settings.

    Reads USE_OPENAI_MODEL from the environment to decide whether to use the
    OpenAI-hosted model or the local one, then loads the matching
    MLLM_END_POINT_*/MLLM_API_KEY_*/MLLM_MODEL_ID_* variables.

    Returns:
        Tuple of (connection_db, mllm_model_settings) where
        mllm_model_settings is a dict with keys "openai_model",
        "mllm_end_point", "mllm_api_key" and "mllm_model_id".
    """
    connection_db = db_persistence_startup(table="wcag_validator_results")

    # disclaim_bool_string presumably normalizes the env string to a bool;
    # the explicit `== True` comparison of the original code is preserved.
    use_openai_flag = disclaim_bool_string(
        return_from_env_valid("USE_OPENAI_MODEL", "False")
    )
    openai_model = use_openai_flag == True  # noqa: E712 — keep original semantics

    # The three settings differ only by provider suffix — look them up once.
    provider = "OPENAI" if openai_model else "LOCAL"
    mllm_end_point = return_from_env_valid(f"MLLM_END_POINT_{provider}", "")
    mllm_api_key = return_from_env_valid(f"MLLM_API_KEY_{provider}", "")
    mllm_model_id = return_from_env_valid(f"MLLM_MODEL_ID_{provider}", "")

    # Log (not print) for consistency with cli(); the API key is never logged.
    logging.info("mllm_end_point: %s", mllm_end_point)
    logging.info("mllm_model_id: %s", mllm_model_id)

    mllm_model_settings = {
        "openai_model": openai_model,
        "mllm_end_point": mllm_end_point,
        "mllm_api_key": mllm_api_key,
        "mllm_model_id": mllm_model_id,
    }

    return connection_db, mllm_model_settings
|
|
|
|
|
|
def run_server(
    connection_db,
    mllm_settings,
    host,
    port,
):
    """Assemble the FastAPI app and serve it with uvicorn (blocking call).

    Args:
        connection_db: database connection forwarded to the app factory.
        mllm_settings: multimodal-LLM settings dict forwarded to the factory.
        host: interface to bind.
        port: TCP port to listen on.
    """
    uvicorn.run(server(connection_db, mllm_settings), host=host, port=port)
|
|
|
|
|
|
def cli(sys_argv):
    """Parse CLI arguments, load the .env configuration and start the server.

    Args:
        sys_argv: argument list without the program name (i.e. sys.argv[1:]).
    """
    # Import environment variables: locate the .env file; only load it when
    # actually found — load_dotenv("") would redundantly search again.
    env_path = find_dotenv(filename=".env")
    if env_path == "":
        print(
            "rest server env path not found: service starting with the default params values"
        )
    else:
        _ = load_dotenv(env_path)  # read .env file

    connection_db, mllm_settings = app_startup()

    # NOTE: the environment keys really are lower-case "rest_port"/"rest_host".
    default_rest_port = return_from_env_valid("rest_port", 8000)
    default_rest_host = return_from_env_valid("rest_host", "0.0.0.0")

    parser = argparse.ArgumentParser()

    # ----------------
    # Server parameters
    # ----------------
    parser.add_argument(
        "-p",
        "--port",
        help="port for server (default: os.environ[rest_port] or 8000)",
        default=default_rest_port,
        type=int,
    )

    parser.add_argument(
        "-H",
        "--host",
        help="host for server (default: os.environ[rest_host] or 0.0.0.0)",
        default=default_rest_host,
    )

    args = parser.parse_args(sys_argv)

    logging.info("service started with cli args:%s", args)

    run_server(
        connection_db,
        mllm_settings,
        args.host,
        args.port,
    )
|
|
|
|
|
|
# Script entry point: forward CLI arguments (minus the program name).
if __name__ == "__main__":
    cli(sys.argv[1:])
|