Spaces:
Sleeping
Sleeping
import json

import httpx
import uvicorn
from fastapi import FastAPI, Request
from starlette.background import BackgroundTask
from starlette.responses import JSONResponse, Response, StreamingResponse
# ASGI application object; it is handed to uvicorn when the file runs as a script.
# NOTE(review): debug=True is enabled here - confirm that is intended outside
# local development.
app = FastAPI(debug=True)

# Define the base URL of your backend server
BACKEND_BASE_URL = "http://localhost:8000"

# Keep-alive timeout passed to uvicorn at startup.
TIMEOUT_KEEP_ALIVE = 5.0

# Outgoing httpx timeouts: 5.0 as the general default, 60.0 for reads.
timeout_config = httpx.Timeout(timeout=5.0, read=60.0)
async def hook(response: httpx.Response) -> None:
    """httpx response event hook that raises on error responses.

    Responses for which ``response.is_error`` is false pass through
    untouched. For error responses the body is read first and then
    ``raise_for_status()`` is called, so the body has already been loaded
    when the exception is raised.
    """
    if not response.is_error:
        return

    # Read before raising: the response may have been requested in streaming
    # mode, in which case its body has not been loaded yet.
    await response.aread()
    response.raise_for_status()
async def forward_get_request(path: str, request: Request):
    """Proxy a GET request to the backend and return its buffered reply.

    Args:
        path: Path (without leading slash) appended to ``BACKEND_BASE_URL``.
        request: Incoming request; its query string is forwarded.

    Returns:
        A ``Response`` carrying the backend's body, status code and content
        type, or a 502 ``JSONResponse`` when the backend cannot be reached.
    """
    try:
        async with httpx.AsyncClient(timeout=timeout_config) as client:
            backend_reply = await client.get(
                f"{BACKEND_BASE_URL}/{path}",
                # multi_items() keeps repeated keys (?a=1&a=2); passing the
                # mapping itself forwards only one value per key.
                params=request.query_params.multi_items(),
            )
    except httpx.RequestError as exc:
        print(f"HTTP Exception for {exc.request.url} - {exc}")
        return JSONResponse(
            {"detail": "Bad gateway: backend request failed"},
            status_code=502,
        )

    # client.get() has already read the whole body, so a plain Response is
    # enough. The backend status is relayed instead of always answering 200,
    # and a missing Content-Type header no longer raises KeyError.
    return Response(
        content=backend_reply.content,
        status_code=backend_reply.status_code,
        media_type=backend_reply.headers.get("content-type"),
    )
async def _relay_then_close(upstream, client):
    """Yield the backend body chunk by chunk, then close the response and client."""
    try:
        async for chunk in upstream.aiter_bytes():
            yield chunk
    finally:
        # Runs when the body is exhausted and also when the generator is
        # closed early, so the connection is released in both cases.
        try:
            await upstream.aclose()
        finally:
            await client.aclose()


async def forward_post_request(path: str, request: Request):
    """Proxy a POST request to the backend and relay its reply.

    The request body is forwarded unchanged. If it is a JSON object with a
    truthy ``"stream"`` key, the backend reply is streamed to the caller
    chunk by chunk; otherwise the reply is read completely and returned in
    one piece.

    Args:
        path: Path (without leading slash) appended to ``BACKEND_BASE_URL``.
        request: Incoming request whose body and headers are forwarded.

    Returns:
        ``StreamingResponse`` for streamed replies, ``Response`` for buffered
        replies and for backend 4xx/5xx replies (status and body relayed),
        or a 502 ``JSONResponse`` when the backend cannot be reached.
    """
    # Retrieve the request body
    body = await request.body()

    # Host and Content-Length describe the incoming request, not the one sent
    # to the backend, so they are left out and httpx fills in its own.
    headers = {
        name: value
        for name, value in request.headers.items()
        if name.lower() not in ("host", "content-length")
    }

    # Decide the reply mode up front. A body that is not JSON, not an object,
    # or has no "stream" key is treated as a non-streaming request.
    try:
        payload = json.loads(body)
    except ValueError:
        payload = None
    wants_stream = isinstance(payload, dict) and bool(payload.get("stream"))

    # The client is managed by hand rather than with "async with": a streamed
    # reply is consumed after this function returns, so the client has to
    # stay open until the relay generator has finished.
    client = httpx.AsyncClient(event_hooks={"response": [hook]}, timeout=timeout_config)
    handed_off = False
    try:
        req = client.build_request(
            "POST", f"{BACKEND_BASE_URL}/{path}", content=body, headers=headers
        )
        # The response hook reads the body and raises HTTPStatusError for
        # 4xx/5xx replies; other statuses are relayed as received.
        upstream = await client.send(req, stream=True)
        try:
            media_type = upstream.headers.get("content-type")

            if not wants_stream:
                # Read the complete body and pass the bytes through as-is.
                content = await upstream.aread()
                return Response(
                    content=content,
                    status_code=upstream.status_code,
                    media_type=media_type or "application/json",
                )

            streamed = StreamingResponse(
                _relay_then_close(upstream, client),
                status_code=upstream.status_code,
                media_type=media_type,
            )
            # From here on the relay generator owns upstream and client.
            handed_off = True
            return streamed
        finally:
            if not handed_off:
                await upstream.aclose()
    except httpx.HTTPStatusError as exc:
        print(f"HTTP Exception for {exc.request.url} - {exc}")
        # The hook has already read the error body, so it can be relayed.
        return Response(
            content=exc.response.content,
            status_code=exc.response.status_code,
            media_type=exc.response.headers.get("content-type"),
        )
    except httpx.RequestError as exc:
        print(f"HTTP Exception for {exc.request.url} - {exc}")
        return JSONResponse(
            {"detail": "Bad gateway: backend request failed"},
            status_code=502,
        )
    finally:
        if not handed_off:
            await client.aclose()
if __name__ == "__main__":
    # Serve the proxy app with uvicorn when this file is run as a script.
    server_options = {
        "host": '127.0.0.1',
        "port": 7860,
        "log_level": "debug",
        "timeout_keep_alive": TIMEOUT_KEEP_ALIVE,
    }
    uvicorn.run(app, **server_options)