rkihacker commited on
Commit
daa63f8
·
verified ·
1 Parent(s): 419de53

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +49 -2
main.py CHANGED
@@ -7,6 +7,7 @@ import random
7
  import logging
8
  import time
9
  from contextlib import asynccontextmanager
 
10
 
11
  # --- Production-Ready Configuration ---
12
  LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
@@ -26,11 +27,56 @@ except ValueError:
26
  logging.error(f"Invalid RETRY_CODES format: '{RETRY_CODES_STR}'. Falling back to default: {DEFAULT_RETRY_CODES}")
27
  RETRY_STATUS_CODES = {int(code.strip()) for code in DEFAULT_RETRY_CODES.split(',')}
28
 
29
- # --- Helper Function ---
 
30
  def generate_random_ip():
31
  """Generates a random, valid-looking IPv4 address."""
32
  return ".".join(str(random.randint(1, 254)) for _ in range(4))
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  # --- HTTPX Client Lifecycle Management ---
35
  @asynccontextmanager
36
  async def lifespan(app: FastAPI):
@@ -97,7 +143,8 @@ async def reverse_proxy_handler(request: Request):
97
  log_func(f"Request finished: {request.method} {request.url.path} status_code={rp_resp.status_code} latency={duration_ms:.2f}ms")
98
 
99
  return StreamingResponse(
100
- rp_resp.aiter_raw(),
 
101
  status_code=rp_resp.status_code,
102
  headers=rp_resp.headers,
103
  background=BackgroundTask(rp_resp.aclose),
 
7
  import logging
8
  import time
9
  from contextlib import asynccontextmanager
10
+ import json
11
 
12
  # --- Production-Ready Configuration ---
13
  LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
 
27
  logging.error(f"Invalid RETRY_CODES format: '{RETRY_CODES_STR}'. Falling back to default: {DEFAULT_RETRY_CODES}")
28
  RETRY_STATUS_CODES = {int(code.strip()) for code in DEFAULT_RETRY_CODES.split(',')}
29
 
30
+ # --- Helper Functions ---
31
+
32
  def generate_random_ip():
33
  """Generates a random, valid-looking IPv4 address."""
34
  return ".".join(str(random.randint(1, 254)) for _ in range(4))
35
 
36
+ async def modified_aiter_raw(original_aiter):
37
+ """
38
+ An async generator that intercepts and modifies the streaming data chunks.
39
+ It adds a prefix to the 'id' and includes a 'provider' field.
40
+ """
41
+ buffer = ""
42
+ async for chunk in original_aiter:
43
+ buffer += chunk.decode('utf-8')
44
+ while '\n' in buffer:
45
+ line, buffer = buffer.split('\n', 1)
46
+ if line.startswith('data:'):
47
+ try:
48
+ # Strip the "data: " prefix to get the JSON string
49
+ json_str = line[len('data: '):].strip()
50
+
51
+ # Process only if it's not the SSE termination message
52
+ if json_str and json_str != '[DONE]':
53
+ data = json.loads(json_str)
54
+
55
+ # Add 'NAI-' prefix to the id
56
+ if 'id' in data:
57
+ data['id'] = f"NAI-{data['id']}"
58
+
59
+ # Add the provider field
60
+ data['provider'] = 'TypeGPT'
61
+
62
+ # Reconstruct the SSE data line
63
+ modified_line = f"data: {json.dumps(data)}"
64
+ yield (modified_line + '\n').encode('utf-8')
65
+ else:
66
+ # Pass through messages like 'data: [DONE]'
67
+ yield (line + '\n').encode('utf-8')
68
+ except json.JSONDecodeError:
69
+ # If it's not valid JSON, pass it through as is
70
+ yield (line + '\n').encode('utf-8')
71
+ else:
72
+ # Pass through non-data lines (e.g., empty lines, comments)
73
+ yield (line + '\n').encode('utf-8')
74
+
75
+ # Yield any remaining data in the buffer
76
+ if buffer:
77
+ yield buffer.encode('utf-8')
78
+
79
+
80
  # --- HTTPX Client Lifecycle Management ---
81
  @asynccontextmanager
82
  async def lifespan(app: FastAPI):
 
143
  log_func(f"Request finished: {request.method} {request.url.path} status_code={rp_resp.status_code} latency={duration_ms:.2f}ms")
144
 
145
  return StreamingResponse(
146
+ # Use the new async generator to modify the stream
147
+ modified_aiter_raw(rp_resp.aiter_raw()),
148
  status_code=rp_resp.status_code,
149
  headers=rp_resp.headers,
150
  background=BackgroundTask(rp_resp.aclose),