certifaier / aiproxy /async_proxy.py
bsmit1659's picture
changing to routing proxy
e2d4dfc
import json

import httpx
import uvicorn
from fastapi import FastAPI, Request
from starlette.background import BackgroundTask
from starlette.responses import JSONResponse, Response, StreamingResponse
# Base URL of the upstream server every proxied request is relayed to.
BACKEND_BASE_URL = "http://localhost:8000"

# Idle keep-alive timeout (seconds) handed to uvicorn in the __main__ block.
TIMEOUT_KEEP_ALIVE = 5.0

# Upstream httpx timeouts: 5 s default (connect/write/pool), 60 s for reads.
timeout_config = httpx.Timeout(timeout=5.0, read=60.0)

# NOTE(review): debug=True makes the app return tracebacks on server errors;
# confirm this is only ever run in development.
app = FastAPI(debug=True)
async def hook(response: httpx.Response) -> None:
    """httpx response event hook: raise ``HTTPStatusError`` for 4xx/5xx replies.

    Successful responses pass through untouched. For error responses the
    body is read first, so that with ``stream=True`` the raised exception's
    ``response`` still has its content available to the caller.
    """
    if not response.is_error:
        return
    await response.aread()
    response.raise_for_status()
@app.get("/{path:path}")
async def forward_get_request(path: str, request: Request):
    """Proxy a GET request to ``BACKEND_BASE_URL/path`` and relay the reply.

    The incoming query parameters are passed along. The backend body is
    fetched in full, then sent back with the backend's status code and its
    ``Content-Type`` when one is present.

    Returns a 502 JSON error if the backend cannot be reached or times out.
    """
    try:
        # Same timeouts as the POST path (60 s read) instead of httpx's 5 s default.
        async with httpx.AsyncClient(timeout=timeout_config) as client:
            response = await client.get(f"{BACKEND_BASE_URL}/{path}", params=request.query_params)
    except httpx.RequestError as exc:
        print(f"HTTP Exception for {exc.request.url} - {exc}")
        return JSONResponse(content={"detail": f"Backend request failed: {exc}"}, status_code=502)
    # A non-streaming `get` has already buffered the body, so iterating it
    # after the client is closed only replays bytes held in memory.
    return StreamingResponse(
        response.aiter_bytes(),
        # Relay the backend status instead of StreamingResponse's default 200.
        status_code=response.status_code,
        # .get(): a reply without Content-Type (e.g. 204) must not raise KeyError.
        media_type=response.headers.get("Content-Type"),
    )
def _wants_stream(body: bytes) -> bool:
    """Return True when *body* is a JSON object with a truthy ``"stream"`` field."""
    try:
        payload = json.loads(body)
    except ValueError:  # empty, non-JSON or non-UTF-8 body: treat as non-streaming
        return False
    return isinstance(payload, dict) and bool(payload.get("stream"))


async def _close_upstream(response: httpx.Response, client: httpx.AsyncClient) -> None:
    """Close a streamed backend response and the client that produced it."""
    await response.aclose()
    await client.aclose()


@app.post("/{path:path}")
async def forward_post_request(path: str, request: Request):
    """Proxy a POST request to ``BACKEND_BASE_URL/path`` and relay the reply.

    The raw request body and headers are forwarded. ``host`` and
    ``content-length`` are dropped because httpx sets them for the outgoing
    request. If the body is a JSON object with a truthy ``"stream"`` field,
    the backend reply is relayed chunk by chunk. Otherwise it is read fully
    and returned in one piece. In both cases the backend's status code and
    ``Content-Type`` are preserved.

    Error handling:
        * Backend 4xx/5xx replies (raised by ``hook``) are relayed with their
          original status and body.
        * Transport failures (connection refused, timeouts) return a 502 JSON
          error.
    """
    body = await request.body()
    headers = {k: v for k, v in request.headers.items() if k.lower() not in ["host", "content-length"]}
    wants_stream = _wants_stream(body)

    # Deliberately not `async with`: a streamed reply is read by Starlette
    # *after* this function returns. Closing the client here would cut the
    # stream, so the client is closed explicitly or by a background task.
    client = httpx.AsyncClient(event_hooks={'response': [hook]}, timeout=timeout_config)
    req = client.build_request("POST", f"{BACKEND_BASE_URL}/{path}", content=body, headers=headers)
    try:
        response = await client.send(req, stream=True)
        if not wants_stream:
            await response.aread()
    except httpx.HTTPStatusError as exc:
        # `hook` has already read the error body, so it can be relayed as-is.
        await client.aclose()
        error = exc.response
        return Response(
            content=error.content,
            status_code=error.status_code,
            media_type=error.headers.get("content-type"),
        )
    except httpx.RequestError as exc:
        await client.aclose()
        print(f"HTTP Exception for {exc.request.url} - {exc}")
        return JSONResponse(content={"detail": f"Backend request failed: {exc}"}, status_code=502)

    # Use the *backend's* content type (e.g. an event stream), never the
    # incoming request headers.
    media_type = response.headers.get("content-type")
    if wants_stream:
        return StreamingResponse(
            response.aiter_bytes(),
            status_code=response.status_code,
            media_type=media_type,
            background=BackgroundTask(_close_upstream, response, client),
        )

    # Body is fully buffered; return the raw bytes (JSONResponse would try to
    # json.dumps() them and fail).
    await client.aclose()
    return Response(content=response.content, status_code=response.status_code, media_type=media_type)
if __name__ == "__main__":
    # Serve the proxy on localhost:7860 with verbose logging; idle keep-alive
    # connections are dropped after TIMEOUT_KEEP_ALIVE seconds.
    server_options = {
        "host": '127.0.0.1',
        "port": 7860,
        "log_level": "debug",
        "timeout_keep_alive": TIMEOUT_KEEP_ALIVE,
    }
    uvicorn.run(app, **server_options)