import asyncio
import json
import logging
import os
import random
import time
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI, Request, HTTPException
from starlette.background import BackgroundTask
from starlette.responses import StreamingResponse, JSONResponse

# --- Production-Ready Configuration ---
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
level=LOG_LEVEL,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# URL to fetch the list of all available models and their endpoints
ARTIFACT_URL = os.getenv("ARTIFACT_URL", "https://console.gmicloud.ai/api/v1/ie/artifact/get_public_artifacts")
# Retry logic configuration
MAX_RETRIES = int(os.getenv("MAX_RETRIES", "5"))
DEFAULT_RETRY_CODES = "429,500,502,503,504"
RETRY_CODES_STR = os.getenv("RETRY_CODES", DEFAULT_RETRY_CODES)
try:
RETRY_STATUS_CODES = {int(code.strip()) for code in RETRY_CODES_STR.split(',')}
logger.info(f"Will retry on the following status codes: {RETRY_STATUS_CODES}")
except ValueError:
logger.error(f"Invalid RETRY_CODES format: '{RETRY_CODES_STR}'. Falling back to default: {DEFAULT_RETRY_CODES}")
RETRY_STATUS_CODES = {int(code.strip()) for code in DEFAULT_RETRY_CODES.split(',')}
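
# Example (illustrative value, not a default): RETRY_CODES="429,503" parses to
# the set {429, 503}; upstream responses with one of those status codes are
# retried up to MAX_RETRIES times with exponential backoff.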
# --- Helper Functions ---
def generate_random_ip():
"""Generates a random, valid-looking IPv4 address."""
return ".".join(str(random.randint(1, 254)) for _ in range(4))
async def fetch_and_cache_models(app: FastAPI):
"""
Fetches the list of public artifacts and caches a routing table.
This runs once on application startup.
"""
logger.info(f"Fetching model artifacts from: {ARTIFACT_URL}")
model_routing_table = {}
try:
async with httpx.AsyncClient() as client:
response = await client.get(ARTIFACT_URL, timeout=30.0)
response.raise_for_status()
artifacts = response.json()
for artifact in artifacts:
model_name = artifact.get("artifact_metadata", {}).get("artifact_name")
endpoints = artifact.get("endpoints", [])
            # Only keep models that have at least one running endpoint
            if model_name and endpoints:
                # A model may expose multiple endpoints; use the first one.
                # A more advanced setup could load-balance across them.
endpoint_url = endpoints[0].get("endpoint_url")
if endpoint_url:
model_routing_table[model_name] = endpoint_url
if not model_routing_table:
logger.warning("No active model endpoints found from artifact URL.")
else:
logger.info(f"Successfully loaded {len(model_routing_table)} active models.")
for name, url in model_routing_table.items():
logger.debug(f" - Model: '{name}' -> Endpoint: '{url}'")
except httpx.RequestError as e:
logger.critical(f"Failed to fetch model artifacts on startup: {e}")
        # In a real-world scenario, you might want the app to fail to start
        # or handle this more gracefully. For now, we start with an empty table.
except Exception as e:
logger.critical(f"An unexpected error occurred during model fetching: {e}")
app.state.model_routing_table = model_routing_table
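
# The parsing above assumes the artifact payload is shaped roughly like the
# sketch below. This is inferred from the fields the code accesses and is an
# assumption, not a documented schema:
#
#   [
#     {
#       "artifact_metadata": {"artifact_name": "<model-id>"},
#       "endpoints": [{"endpoint_url": "<backend-host>"}]
#     }
#   ]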
# --- HTTPX Client Lifecycle Management ---
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Manages the app's lifecycle for startup and shutdown."""
# Create a single, long-lived HTTP client for forwarding requests
# No base_url as we will be calling different hosts dynamically
async with httpx.AsyncClient(timeout=None) as client:
app.state.http_client = client
# Fetch and cache model routes on startup
await fetch_and_cache_models(app)
yield
logger.info("Application shutdown complete.")
# Initialize the FastAPI app with the lifespan manager and disabled docs
app = FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)
# --- API Endpoints ---
@app.get("/")
async def health_check():
"""Provides a basic health check endpoint."""
return JSONResponse({
"status": "ok",
"active_models": len(app.state.model_routing_table)
})
@app.get("/v1/models")
async def list_models(request: Request):
"""
Lists all available models discovered at startup.
Formatted to be compatible with the OpenAI API.
"""
model_routing_table = request.app.state.model_routing_table
model_list = [
{
"id": model_id,
"object": "model",
"created": int(time.time()),
"owned_by": "gmi-serving",
}
for model_id in model_routing_table.keys()
]
return JSONResponse(content={"object": "list", "data": model_list})
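
# Example (assuming the proxy is running locally on port 8000; host and port
# are assumptions for illustration):
#   curl http://localhost:8000/v1/models
# returns {"object": "list", "data": [{"id": "<model-id>", ...}]}, matching the
# /v1/models response shape that OpenAI-compatible clients expect.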
@app.post("/v1/chat/completions")
async def chat_completions_proxy(request: Request):
"""
Forwards chat completion requests to the correct model endpoint.
"""
start_time = time.monotonic()
# --- 1. Get Model Name and Find Target Host ---
body = await request.body()
try:
data = json.loads(body)
model_name = data.get("model")
if not model_name:
raise HTTPException(status_code=400, detail="Missing 'model' field in request body.")
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="Invalid JSON in request body.")
model_routing_table = request.app.state.model_routing_table
target_host = model_routing_table.get(model_name)
if not target_host:
raise HTTPException(
status_code=404,
detail=f"Model '{model_name}' not found or is not currently active."
)
# --- 2. Prepare and Forward the Request ---
client: httpx.AsyncClient = request.app.state.http_client
# Construct the full URL to the backend service
target_url = f"https://{target_host}{request.url.path}"
request_headers = dict(request.headers)
request_headers.pop("host", None)
random_ip = generate_random_ip()
spoof_headers = {
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/141.0.0.0 Safari/537.36",
"x-forwarded-for": random_ip,
"x-real-ip": random_ip,
}
request_headers.update(spoof_headers)
logger.info(
f"Routing request for model '{model_name}' to {target_url} "
f"(Client: '{request.client.host}', Spoofed IP: {random_ip})"
)
# --- 3. Execute with Retry Logic ---
last_exception = None
for attempt in range(MAX_RETRIES):
try:
rp_req = client.build_request(
method=request.method, url=target_url, headers=request_headers, content=body
)
rp_resp = await client.send(rp_req, stream=True)
# If status is not retryable OR it's the last attempt, stream the response
if rp_resp.status_code not in RETRY_STATUS_CODES or attempt == MAX_RETRIES - 1:
duration_ms = (time.monotonic() - start_time) * 1000
log_func = logger.info if rp_resp.is_success else logger.warning
log_func(f"Request finished for '{model_name}': {request.method} {request.url.path} status_code={rp_resp.status_code} latency={duration_ms:.2f}ms")
return StreamingResponse(
rp_resp.aiter_raw(),
status_code=rp_resp.status_code,
headers=rp_resp.headers,
background=BackgroundTask(rp_resp.aclose),
)
# Otherwise, log and prepare for retry
logger.warning(
f"Attempt {attempt + 1}/{MAX_RETRIES} for '{model_name}' failed with status {rp_resp.status_code}. Retrying..."
)
await rp_resp.aclose() # Ensure the connection is closed before retrying
await asyncio.sleep(1 * (2 ** attempt)) # Exponential backoff
        except httpx.ConnectError as e:
            last_exception = e
            logger.warning(f"Attempt {attempt + 1}/{MAX_RETRIES} for '{model_name}' failed with connection error: {e}")
            if attempt < MAX_RETRIES - 1:
                await asyncio.sleep(1 * (2 ** attempt))  # Back off before retrying, mirroring the HTTP retry path
except Exception as e:
last_exception = e
logger.error(f"An unexpected error occurred during request forwarding: {e}")
break # Don't retry on unexpected errors
    # --- 4. Handle Final Failure (connection errors exhausted or a fatal error) ---
    duration_ms = (time.monotonic() - start_time) * 1000
    logger.critical(f"Request failed for model '{model_name}'. Could not reach target: {target_url}. Last error: {last_exception}. Latency: {duration_ms:.2f}ms")
    raise HTTPException(
        status_code=502,
        detail=f"Bad Gateway: Cannot reach model backend for '{model_name}'. Last error: {last_exception}"
    )
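
# Example request (assuming the proxy runs locally on port 8000 and that a
# model named "llama-3-8b-instruct" was discovered at startup; both are
# assumptions for illustration):
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "llama-3-8b-instruct", "messages": [{"role": "user", "content": "Hi"}]}'

# Minimal local entrypoint sketch. A hosting platform may launch the app itself
# (e.g. `uvicorn main:app`); the bind address and port here are assumptions.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)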