import asyncio
import json
import logging
import os
import random
import time
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI, HTTPException, Request
from starlette.background import BackgroundTask
from starlette.responses import JSONResponse, StreamingResponse

# Logging level is configurable via the LOG_LEVEL environment variable.
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
    level=LOG_LEVEL,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
|
|
# URL of the artifact registry that lists the publicly available models.
ARTIFACT_URL = os.getenv(
    "ARTIFACT_URL",
    "https://console.gmicloud.ai/api/v1/ie/artifact/get_public_artifacts",
)

# Retry configuration: how many attempts each proxied request gets, and which
# upstream status codes should trigger a retry.
MAX_RETRIES = int(os.getenv("MAX_RETRIES", "5"))
DEFAULT_RETRY_CODES = "429,500,502,503,504"
RETRY_CODES_STR = os.getenv("RETRY_CODES", DEFAULT_RETRY_CODES)
try:
    RETRY_STATUS_CODES = {int(code.strip()) for code in RETRY_CODES_STR.split(',')}
    logger.info(f"Will retry on the following status codes: {RETRY_STATUS_CODES}")
except ValueError:
    logger.error(
        f"Invalid RETRY_CODES format: '{RETRY_CODES_STR}'. "
        f"Falling back to default: {DEFAULT_RETRY_CODES}"
    )
    RETRY_STATUS_CODES = {int(code.strip()) for code in DEFAULT_RETRY_CODES.split(',')}
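# Example override at deploy time (the values are illustrative, not the
# defaults, and "main:app" assumes this module is saved as main.py):
#
#   MAX_RETRIES=3 RETRY_CODES="429,503" LOG_LEVEL=DEBUG uvicorn main:app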
|
|
def generate_random_ip():
    """Generates a random, valid-looking IPv4 address."""
    return ".".join(str(random.randint(1, 254)) for _ in range(4))
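# Example output (illustrative): "203.91.47.12". Each octet is drawn from
# 1-254, so the address only needs to look plausible, not be routable.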
|
|
async def fetch_and_cache_models(app: FastAPI):
    """
    Fetches the list of public artifacts and caches a routing table mapping
    model names to endpoint URLs. This runs once on application startup.
    """
    logger.info(f"Fetching model artifacts from: {ARTIFACT_URL}")
    model_routing_table = {}
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(ARTIFACT_URL, timeout=30.0)
            response.raise_for_status()
            artifacts = response.json()

        for artifact in artifacts:
            model_name = artifact.get("artifact_metadata", {}).get("artifact_name")
            endpoints = artifact.get("endpoints", [])
            # Only models that expose at least one endpoint are routable; the
            # first listed endpoint is used.
            if model_name and endpoints:
                endpoint_url = endpoints[0].get("endpoint_url")
                if endpoint_url:
                    model_routing_table[model_name] = endpoint_url

        if not model_routing_table:
            logger.warning("No active model endpoints found at the artifact URL.")
        else:
            logger.info(f"Successfully loaded {len(model_routing_table)} active models.")
            for name, url in model_routing_table.items():
                logger.debug(f"  - Model: '{name}' -> Endpoint: '{url}'")
    except httpx.RequestError as e:
        logger.critical(f"Failed to fetch model artifacts on startup: {e}")
    except Exception as e:
        logger.critical(f"An unexpected error occurred during model fetching: {e}")

    # On failure an empty table is cached so the app can still start and
    # serve its health endpoint; routing requests will then return 404.
    app.state.model_routing_table = model_routing_table
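# For reference, the parser above assumes the artifact endpoint returns a JSON
# list shaped roughly like this (a sketch inferred from the fields accessed;
# the real payload may carry additional keys):
#
#   [
#     {
#       "artifact_metadata": {"artifact_name": "some-model"},
#       "endpoints": [{"endpoint_url": "some-model.example.com"}]
#     }
#   ]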
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manages the app's lifecycle for startup and shutdown."""
    # One shared client is reused for every proxied request; timeout=None so
    # that long-lived streaming responses are not cut off mid-stream.
    async with httpx.AsyncClient(timeout=None) as client:
        app.state.http_client = client
        await fetch_and_cache_models(app)
        yield
    logger.info("Application shutdown complete.")


app = FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)
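# To run locally (the module name and port are assumptions, not requirements):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000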
|
|
@app.get("/")
async def health_check():
    """Provides a basic health check endpoint."""
    return JSONResponse({
        "status": "ok",
        "active_models": len(app.state.model_routing_table),
    })
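# Quick smoke test (port matches the example run command above; the model
# count depends on what was discovered at startup):
#
#   curl http://localhost:8000/
#   {"status": "ok", "active_models": 2}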
|
|
@app.get("/v1/models")
async def list_models(request: Request):
    """
    Lists all available models discovered at startup.
    Formatted to be compatible with the OpenAI API.
    """
    model_routing_table = request.app.state.model_routing_table
    model_list = [
        {
            "id": model_id,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "gmi-serving",
        }
        for model_id in model_routing_table
    ]
    return JSONResponse(content={"object": "list", "data": model_list})
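# Example response (the model id is a placeholder; actual ids depend on what
# the artifact registry returned at startup):
#
#   {"object": "list", "data": [{"id": "some-model", "object": "model",
#    "created": 1700000000, "owned_by": "gmi-serving"}]}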
|
|
@app.post("/v1/chat/completions")
async def chat_completions_proxy(request: Request):
    """
    Forwards chat completion requests to the correct model endpoint.
    """
    start_time = time.monotonic()

    # The body is parsed only to extract the model name for routing; the raw
    # bytes are kept so they can be forwarded upstream unmodified.
    body = await request.body()
    try:
        data = json.loads(body)
        model_name = data.get("model")
        if not model_name:
            raise HTTPException(status_code=400, detail="Missing 'model' field in request body.")
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid JSON in request body.")

    model_routing_table = request.app.state.model_routing_table
    target_host = model_routing_table.get(model_name)

    if not target_host:
        raise HTTPException(
            status_code=404,
            detail=f"Model '{model_name}' not found or is not currently active."
        )

    client: httpx.AsyncClient = request.app.state.http_client

    target_url = f"https://{target_host}{request.url.path}"

    # Drop the incoming Host header so httpx sets one that matches the target.
    request_headers = dict(request.headers)
    request_headers.pop("host", None)

    random_ip = generate_random_ip()
    spoof_headers = {
        "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/141.0.0.0 Safari/537.36",
        "x-forwarded-for": random_ip,
        "x-real-ip": random_ip,
    }
    request_headers.update(spoof_headers)

    client_host = request.client.host if request.client else "unknown"
    logger.info(
        f"Routing request for model '{model_name}' to {target_url} "
        f"(Client: '{client_host}', Spoofed IP: {random_ip})"
    )

    last_exception = None
    for attempt in range(MAX_RETRIES):
        try:
            rp_req = client.build_request(
                method=request.method, url=target_url, headers=request_headers, content=body
            )
            rp_resp = await client.send(rp_req, stream=True)

            # Stream the response back unless its status code is retryable
            # and attempts remain.
            if rp_resp.status_code not in RETRY_STATUS_CODES or attempt == MAX_RETRIES - 1:
                duration_ms = (time.monotonic() - start_time) * 1000
                log_func = logger.info if rp_resp.is_success else logger.warning
                log_func(
                    f"Request finished for '{model_name}': {request.method} "
                    f"{request.url.path} status_code={rp_resp.status_code} "
                    f"latency={duration_ms:.2f}ms"
                )

                return StreamingResponse(
                    rp_resp.aiter_raw(),
                    status_code=rp_resp.status_code,
                    headers=rp_resp.headers,
                    background=BackgroundTask(rp_resp.aclose),
                )

            logger.warning(
                f"Attempt {attempt + 1}/{MAX_RETRIES} for '{model_name}' failed with status {rp_resp.status_code}. Retrying..."
            )
            await rp_resp.aclose()
            # Exponential backoff: 1s, 2s, 4s, ...
            await asyncio.sleep(1 * (2 ** attempt))

        except httpx.ConnectError as e:
            last_exception = e
            logger.warning(f"Attempt {attempt + 1}/{MAX_RETRIES} for '{model_name}' failed with connection error: {e}")
            # Back off here too, but skip the pointless sleep on the last attempt.
            if attempt + 1 < MAX_RETRIES:
                await asyncio.sleep(1 * (2 ** attempt))
        except Exception as e:
            last_exception = e
            logger.error(f"An unexpected error occurred during request forwarding: {e}")
            break

    duration_ms = (time.monotonic() - start_time) * 1000
    logger.critical(f"Request failed for model '{model_name}' after {MAX_RETRIES} attempts. Target: {target_url}. Latency: {duration_ms:.2f}ms")

    raise HTTPException(
        status_code=502,
        detail=f"Bad Gateway: Cannot reach model backend for '{model_name}' after {MAX_RETRIES} attempts. Last error: {last_exception}"
    )
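# Example end-to-end call (host, port, and model name are placeholders; any
# OpenAI-compatible client pointed at this proxy works the same way):
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "some-model", "messages": [{"role": "user", "content": "Hi"}]}'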