Update main.py
Browse files
main.py
CHANGED
|
@@ -7,6 +7,7 @@ import random
|
|
| 7 |
import logging
|
| 8 |
import time
|
| 9 |
from contextlib import asynccontextmanager
|
|
|
|
| 10 |
|
| 11 |
# --- Production-Ready Configuration ---
|
| 12 |
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
|
|
@@ -26,11 +27,56 @@ except ValueError:
|
|
| 26 |
logging.error(f"Invalid RETRY_CODES format: '{RETRY_CODES_STR}'. Falling back to default: {DEFAULT_RETRY_CODES}")
|
| 27 |
RETRY_STATUS_CODES = {int(code.strip()) for code in DEFAULT_RETRY_CODES.split(',')}
|
| 28 |
|
| 29 |
-
# --- Helper
|
|
|
|
| 30 |
def generate_random_ip():
|
| 31 |
"""Generates a random, valid-looking IPv4 address."""
|
| 32 |
return ".".join(str(random.randint(1, 254)) for _ in range(4))
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
# --- HTTPX Client Lifecycle Management ---
|
| 35 |
@asynccontextmanager
|
| 36 |
async def lifespan(app: FastAPI):
|
|
@@ -97,7 +143,8 @@ async def reverse_proxy_handler(request: Request):
|
|
| 97 |
log_func(f"Request finished: {request.method} {request.url.path} status_code={rp_resp.status_code} latency={duration_ms:.2f}ms")
|
| 98 |
|
| 99 |
return StreamingResponse(
|
| 100 |
-
|
|
|
|
| 101 |
status_code=rp_resp.status_code,
|
| 102 |
headers=rp_resp.headers,
|
| 103 |
background=BackgroundTask(rp_resp.aclose),
|
|
|
|
| 7 |
import logging
|
| 8 |
import time
|
| 9 |
from contextlib import asynccontextmanager
|
| 10 |
+
import json
|
| 11 |
|
| 12 |
# --- Production-Ready Configuration ---
|
| 13 |
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
|
|
|
|
| 27 |
logging.error(f"Invalid RETRY_CODES format: '{RETRY_CODES_STR}'. Falling back to default: {DEFAULT_RETRY_CODES}")
|
| 28 |
RETRY_STATUS_CODES = {int(code.strip()) for code in DEFAULT_RETRY_CODES.split(',')}
|
| 29 |
|
| 30 |
+
# --- Helper Functions ---
|
| 31 |
+
|
| 32 |
def generate_random_ip():
|
| 33 |
"""Generates a random, valid-looking IPv4 address."""
|
| 34 |
return ".".join(str(random.randint(1, 254)) for _ in range(4))
|
| 35 |
|
| 36 |
+
async def modified_aiter_raw(original_aiter):
|
| 37 |
+
"""
|
| 38 |
+
An async generator that intercepts and modifies the streaming data chunks.
|
| 39 |
+
It adds a prefix to the 'id' and includes a 'provider' field.
|
| 40 |
+
"""
|
| 41 |
+
buffer = ""
|
| 42 |
+
async for chunk in original_aiter:
|
| 43 |
+
buffer += chunk.decode('utf-8')
|
| 44 |
+
while '\n' in buffer:
|
| 45 |
+
line, buffer = buffer.split('\n', 1)
|
| 46 |
+
if line.startswith('data:'):
|
| 47 |
+
try:
|
| 48 |
+
# Strip the "data: " prefix to get the JSON string
|
| 49 |
+
json_str = line[len('data: '):].strip()
|
| 50 |
+
|
| 51 |
+
# Process only if it's not the SSE termination message
|
| 52 |
+
if json_str and json_str != '[DONE]':
|
| 53 |
+
data = json.loads(json_str)
|
| 54 |
+
|
| 55 |
+
# Add 'NAI-' prefix to the id
|
| 56 |
+
if 'id' in data:
|
| 57 |
+
data['id'] = f"NAI-{data['id']}"
|
| 58 |
+
|
| 59 |
+
# Add the provider field
|
| 60 |
+
data['provider'] = 'TypeGPT'
|
| 61 |
+
|
| 62 |
+
# Reconstruct the SSE data line
|
| 63 |
+
modified_line = f"data: {json.dumps(data)}"
|
| 64 |
+
yield (modified_line + '\n').encode('utf-8')
|
| 65 |
+
else:
|
| 66 |
+
# Pass through messages like 'data: [DONE]'
|
| 67 |
+
yield (line + '\n').encode('utf-8')
|
| 68 |
+
except json.JSONDecodeError:
|
| 69 |
+
# If it's not valid JSON, pass it through as is
|
| 70 |
+
yield (line + '\n').encode('utf-8')
|
| 71 |
+
else:
|
| 72 |
+
# Pass through non-data lines (e.g., empty lines, comments)
|
| 73 |
+
yield (line + '\n').encode('utf-8')
|
| 74 |
+
|
| 75 |
+
# Yield any remaining data in the buffer
|
| 76 |
+
if buffer:
|
| 77 |
+
yield buffer.encode('utf-8')
|
| 78 |
+
|
| 79 |
+
|
| 80 |
# --- HTTPX Client Lifecycle Management ---
|
| 81 |
@asynccontextmanager
|
| 82 |
async def lifespan(app: FastAPI):
|
|
|
|
| 143 |
log_func(f"Request finished: {request.method} {request.url.path} status_code={rp_resp.status_code} latency={duration_ms:.2f}ms")
|
| 144 |
|
| 145 |
return StreamingResponse(
|
| 146 |
+
# Use the new async generator to modify the stream
|
| 147 |
+
modified_aiter_raw(rp_resp.aiter_raw()),
|
| 148 |
status_code=rp_resp.status_code,
|
| 149 |
headers=rp_resp.headers,
|
| 150 |
background=BackgroundTask(rp_resp.aclose),
|