# NOTE: the three lines below were Hugging Face Hub page chrome accidentally
# captured when this file was downloaded; kept as comments so the file parses.
# gabrielaltay's picture
# Upload app.py with huggingface_hub
# 46f474a verified
#!/usr/bin/env python3
"""
ChromaDB Auth Proxy (robust passthrough)
- Bearer auth at the edge
- Streams/buffers appropriately
- Preserves Content-Type, avoids JSON re-serialization
- Reuses a single AsyncClient (HTTP/2, pooled)
- Filters hop-by-hop headers
- Maps network errors to 502/504
"""
import asyncio
import logging
import os
import time
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Dict
import httpx
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
import uvicorn
# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
# -------------------------
# Configuration
# -------------------------
# Upstream ChromaDB location; defaults suit local development.
CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8001"))
# Port this proxy itself listens on.
PROXY_PORT = int(os.getenv("PROXY_PORT", "7860"))
# Shared Bearer token clients must present at the edge.
# NOTE(review): "test_token_123" is a weak fallback — set CHROMA_AUTH_TOKEN
# in any real deployment.
AUTH_TOKEN = os.getenv("CHROMA_AUTH_TOKEN", "test_token_123")
# Timeout configuration (in seconds)
TIMEOUT_CONNECT = 10.0
TIMEOUT_READ = 60.0 * 8  # 8 min: large upstream operations (e.g. big deletes) can be slow
TIMEOUT_WRITE = 60.0 * 2  # 2 min for writing large request bodies upstream
TIMEOUT_POOL = None  # no limit waiting for a pooled connection
# -------------------------
# Security
# -------------------------
# FastAPI dependency that parses "Authorization: Bearer <token>" headers.
security = HTTPBearer()
# -------------------------
# Lifespan management
# -------------------------
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Manage application lifespan - startup and shutdown.

    On shutdown the shared httpx client is closed so pooled connections are
    released cleanly. The close runs in a ``finally`` so it happens even if
    an exception propagates through the application's lifetime (the original
    code skipped cleanup in that case, leaking the connection pool).
    """
    logger.info("πŸš€ Starting ChromaDB Auth Proxy lifespan")
    try:
        yield
    finally:
        logger.info("πŸ›‘ Shutting down ChromaDB Auth Proxy")
        await _client.aclose()
app = FastAPI(title="ChromaDB Auth Proxy", lifespan=lifespan)
@app.get("/")
async def root():
    """Liveness endpoint: report that the proxy process is up."""
    payload = {"status": "ok", "service": "chromadb-auth-proxy"}
    return payload
@app.get("/health")
async def health():
    """Health-check endpoint for orchestrators / load balancers."""
    payload = {"status": "healthy", "service": "chromadb-auth-proxy"}
    return payload
async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Check the presented Bearer token against the configured AUTH_TOKEN.

    Returns the credentials unchanged on a match; otherwise rejects the
    request with HTTP 401 and a ``WWW-Authenticate: Bearer`` challenge.
    """
    supplied = credentials.credentials
    if supplied == AUTH_TOKEN:
        return credentials
    raise HTTPException(
        status_code=401,
        detail="Invalid authentication token",
        headers={"WWW-Authenticate": "Bearer"},
    )
# -------------------------
# HTTP client (shared)
# -------------------------
# One module-level client so every proxied request reuses the same
# connection pool (HTTP/2 enabled); it is closed in the app lifespan.
# Increased timeouts for large operations (collection deletion with 200k+ docs)
_client = httpx.AsyncClient(
    http2=True,
    timeout=httpx.Timeout(
        connect=TIMEOUT_CONNECT,
        read=TIMEOUT_READ,
        write=TIMEOUT_WRITE,
        pool=TIMEOUT_POOL,
    ),
    limits=httpx.Limits(max_keepalive_connections=20, max_connections=100),
)
# Hop-by-hop headers we should not forward
# (connection-specific per RFC 9110 §7.6.1; they must not cross a proxy)
HOP_BY_HOP = {
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
}
# Response headers to pass through (allow-list)
PASS_HEADERS = {
    "content-type",
    "cache-control",
    "etag",
    "last-modified",
    "expires",
    "vary",
    "location",
    "content-disposition",
    "content-encoding",
    "x-chroma-trace-id",
}
def _filter_resp_headers(upstream: httpx.Response) -> Dict[str, str]:
    """Select which upstream response headers may be relayed to the client.

    Hop-by-hop headers are dropped outright; of the rest, only names on the
    PASS_HEADERS allow-list survive (original header casing is kept).
    """
    return {
        name: value
        for name, value in upstream.headers.items()
        if (lowered := name.lower()) not in HOP_BY_HOP and lowered in PASS_HEADERS
    }
@app.api_route(
    "/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]
)
async def proxy_request(request: Request, path: str, _=Depends(verify_token)):
    """Forward one authenticated request to ChromaDB and relay the response.

    - Strips the edge Authorization header, hop-by-hop headers, and the
      stale Content-Length before forwarding (httpx recomputes the length).
    - Buffers JSON responses and passes the bytes through unchanged;
      streams every other content type chunk by chunk.
    - Maps upstream network failures to 502/504.

    BUG FIX: the original wrapped the upstream call in
    ``async with _client.stream(...)``, which closes the upstream response
    as soon as this function returns — *before* StreamingResponse iterates
    the body, so streamed responses read from a closed stream. We now use
    ``build_request``/``send(stream=True)`` and close the response
    explicitly on every exit path.
    """
    start_time = time.time()
    target_url = f"http://{CHROMA_HOST}:{CHROMA_PORT}/{path}"
    # Special logging for DELETE operations
    if request.method == "DELETE":
        logger.warning(
            f"⚠️ DELETE operation - may take up to {int(TIMEOUT_READ)}s for large collections"
        )
    # Query params
    params = dict(request.query_params)
    # Forward headers except host/authorization (proxy concerns), hop-by-hop
    # headers, and content-length (recomputed by httpx from `content`).
    skip_request_headers = HOP_BY_HOP | {"host", "authorization", "content-length"}
    fwd_headers = {
        k: v for k, v in request.headers.items() if k.lower() not in skip_request_headers
    }
    # Only read body for write-ish methods
    body = None
    if request.method in {"POST", "PUT", "PATCH"}:
        body = await request.body()
        logger.info(f" Request body size: {len(body)} bytes")
    try:
        upstream_start = time.time()
        upstream = await _client.send(
            _client.build_request(
                method=request.method,
                url=target_url,
                params=params,
                headers=fwd_headers,
                content=body,
            ),
            stream=True,
        )
        upstream_time = time.time() - upstream_start
        status = upstream.status_code
        resp_headers = _filter_resp_headers(upstream)
        logger.info(
            f" βœ… Upstream response: {status} (took {upstream_time:.2f}s)"
        )
        # HEAD / 204: no body to relay; release the connection immediately
        if request.method == "HEAD" or status == 204:
            await upstream.aclose()
            total_time = time.time() - start_time
            logger.info(
                f" πŸ“€ Returning HEAD/204 response (total: {total_time:.2f}s)"
            )
            return Response(status_code=status, headers=resp_headers)
        ctype = upstream.headers.get("content-type", "")
        # If JSON, buffer minimally and pass through bytes unchanged
        if ctype.startswith("application/json"):
            json_start = time.time()
            try:
                data = await upstream.aread()
            finally:
                # Release the connection whether the read succeeded or timed out
                await upstream.aclose()
            json_time = time.time() - json_start
            total_time = time.time() - start_time
            logger.info(
                f" πŸ“€ Returning JSON response: {len(data)} bytes (json: {json_time:.2f}s, total: {total_time:.2f}s)"
            )
            return Response(
                content=data,
                status_code=status,
                headers=resp_headers,
                media_type=ctype,
            )

        # Otherwise stream raw chunks; the generator owns the upstream
        # response and closes it once the client has consumed the body.
        async def _aiter():
            chunk_count = 0
            total_bytes = 0
            try:
                async for chunk in upstream.aiter_raw():
                    if chunk:
                        chunk_count += 1
                        total_bytes += len(chunk)
                        yield chunk
                        # be nice to the event loop
                        await asyncio.sleep(0)
                logger.info(f" πŸ“€ Streamed {chunk_count} chunks, {total_bytes} bytes")
            finally:
                await upstream.aclose()

        return StreamingResponse(
            _aiter(),
            status_code=status,
            headers=resp_headers,
            media_type=ctype or None,
        )
    except httpx.ConnectTimeout:
        total_time = time.time() - start_time
        logger.error(f" ❌ Connect timeout after {total_time:.2f}s")
        raise HTTPException(status_code=504, detail="Chroma upstream connect timeout")
    except httpx.ReadTimeout:
        total_time = time.time() - start_time
        logger.error(f" ❌ Read timeout after {total_time:.2f}s")
        raise HTTPException(status_code=504, detail="Chroma upstream read timeout")
    except httpx.ConnectError as e:
        total_time = time.time() - start_time
        logger.error(f" ❌ Connect error after {total_time:.2f}s: {e}")
        raise HTTPException(
            status_code=502, detail=f"Chroma upstream connect error: {e}"
        )
    except httpx.TransportError as e:
        total_time = time.time() - start_time
        logger.error(f" ❌ Transport error after {total_time:.2f}s: {e}")
        raise HTTPException(
            status_code=502, detail=f"Chroma upstream transport error: {e}"
        )
    except Exception as e:
        total_time = time.time() - start_time
        logger.error(f" ❌ Unexpected error after {total_time:.2f}s: {e}")
        raise HTTPException(status_code=500, detail=f"Internal proxy error: {e}")
if __name__ == "__main__":
    # Announce the effective configuration before handing control to uvicorn.
    print("πŸš€ Starting ChromaDB Auth Proxy")
    print(f" Proxy URL: http://0.0.0.0:{PROXY_PORT}")
    print(f" ChromaDB URL: http://{CHROMA_HOST}:{CHROMA_PORT}")
    print(
        f" Timeouts: connect={int(TIMEOUT_CONNECT)}s, read={int(TIMEOUT_READ)}s, write={int(TIMEOUT_WRITE)}s"
    )
    # Fixed: was an f-string with no placeholders (ruff F541); output unchanged.
    print(" Logging: INFO level")
    logger.info("ChromaDB Auth Proxy starting up")
    uvicorn.run(app, host="0.0.0.0", port=PROXY_PORT)