#!/usr/bin/env python3
"""
ChromaDB Auth Proxy (robust passthrough)
- Bearer auth at the edge
- Streams/buffers appropriately
- Preserves Content-Type, avoids JSON re-serialization
- Reuses a single AsyncClient (HTTP/2, pooled)
- Filters hop-by-hop headers
- Maps network errors to 502/504
"""

import asyncio
import logging
import os
import time
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Dict

import httpx
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

# BackgroundTask lets the streaming path close the upstream response only
# after the client has fully consumed it (see proxy_request below).
from starlette.background import BackgroundTask

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# -------------------------
# Configuration
# -------------------------
CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8001"))
PROXY_PORT = int(os.getenv("PROXY_PORT", "7860"))
AUTH_TOKEN = os.getenv("CHROMA_AUTH_TOKEN", "test_token_123")

# Timeout configuration (in seconds)
TIMEOUT_CONNECT = 10.0
TIMEOUT_READ = 60.0 * 8
TIMEOUT_WRITE = 60.0 * 2
TIMEOUT_POOL = None
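
# Example environment overrides for the settings above (illustrative only;
# the variable names match the os.getenv() calls, the values are placeholders):
#
#   export CHROMA_HOST=chroma.internal
#   export CHROMA_PORT=8001
#   export PROXY_PORT=7860
#   export CHROMA_AUTH_TOKEN=<a-strong-secret-not-the-test-default>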

# -------------------------
# Security
# -------------------------
security = HTTPBearer()

# -------------------------
# Lifespan management
# -------------------------
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Manage application lifespan - startup and shutdown."""
    logger.info("Starting ChromaDB Auth Proxy lifespan")
    yield
    logger.info("Shutting down ChromaDB Auth Proxy")
    await _client.aclose()


app = FastAPI(title="ChromaDB Auth Proxy", lifespan=lifespan)


# Unauthenticated status endpoints
@app.get("/")
async def root():
    return {"status": "ok", "service": "chromadb-auth-proxy"}


@app.get("/health")
async def health():
    return {"status": "healthy", "service": "chromadb-auth-proxy"}


async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    if credentials.credentials != AUTH_TOKEN:
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return credentials
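
# Clients authenticate to the proxy with a standard Bearer header, e.g.
# (illustrative request; any ChromaDB endpoint path works, the token shown is
# the default test value configured above):
#
#   GET /api/v2/heartbeat HTTP/1.1
#   Authorization: Bearer test_token_123
#
# The header is consumed here and not forwarded upstream (proxy_request
# strips "authorization" before contacting ChromaDB).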

# -------------------------
# HTTP client (shared)
# -------------------------
# Increased timeouts for large operations (collection deletion with 200k+ docs)
_client = httpx.AsyncClient(
    http2=True,
    timeout=httpx.Timeout(
        connect=TIMEOUT_CONNECT,
        read=TIMEOUT_READ,
        write=TIMEOUT_WRITE,
        pool=TIMEOUT_POOL,
    ),
    limits=httpx.Limits(max_keepalive_connections=20, max_connections=100),
)
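# NOTE: http2=True needs the optional "h2" package (pip install "httpx[http2]");
# plain HTTP/1.1 works if http2 is set to False.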

# Hop-by-hop headers we should not forward
HOP_BY_HOP = {
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
}

# Response headers to pass through (allow-list)
PASS_HEADERS = {
    "content-type",
    "cache-control",
    "etag",
    "last-modified",
    "expires",
    "vary",
    "location",
    "content-disposition",
    "content-encoding",
    "x-chroma-trace-id",
}


def _filter_resp_headers(upstream: httpx.Response) -> Dict[str, str]:
    """Drop hop-by-hop and computed headers; keep useful ones."""
    out: Dict[str, str] = {}
    for k, v in upstream.headers.items():
        kl = k.lower()
        if kl in HOP_BY_HOP:
            continue
        if kl in PASS_HEADERS:
            out[k] = v
    return out


# Catch-all proxy route: every path and method is authenticated, then
# forwarded to ChromaDB. The explicit routes above take precedence.
@app.api_route(
    "/{path:path}",
    methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"],
)
async def proxy_request(request: Request, path: str, _=Depends(verify_token)):
    start_time = time.time()
    target_url = f"http://{CHROMA_HOST}:{CHROMA_PORT}/{path}"

    # Special logging for DELETE operations
    if request.method == "DELETE":
        logger.warning(
            f"DELETE operation - may take up to {int(TIMEOUT_READ)}s for large collections"
        )

    # Query params
    params = dict(request.query_params)

    # Forward headers except host & auth (the bearer token is for this proxy
    # only and must not leak to ChromaDB)
    fwd_headers = {}
    for k, v in request.headers.items():
        kl = k.lower()
        if kl in ("host", "authorization"):
            continue
        fwd_headers[k] = v

    # Only read body for write-ish methods
    body = None
    body_size = 0
    if request.method in {"POST", "PUT", "PATCH"}:
        body = await request.body()
        body_size = len(body)
        logger.info(f"Request body size: {body_size} bytes")
    try:
        upstream_start = time.time()
        # Build the request and send with stream=True so large responses can
        # be relayed without buffering them fully in the proxy. Using
        # build_request/send (instead of a context manager) keeps the upstream
        # response open until the streamed body has been delivered.
        upstream_req = _client.build_request(
            method=request.method,
            url=target_url,
            params=params,
            headers=fwd_headers,
            content=body,
        )
        upstream = await _client.send(upstream_req, stream=True)
        upstream_time = time.time() - upstream_start
        status = upstream.status_code
        resp_headers = _filter_resp_headers(upstream)
        logger.info(f"Upstream response: {status} (took {upstream_time:.2f}s)")

        # HEAD / 204: no body
        if request.method == "HEAD" or status == 204:
            await upstream.aclose()
            total_time = time.time() - start_time
            logger.info(f"Returning HEAD/204 response (total: {total_time:.2f}s)")
            return Response(status_code=status, headers=resp_headers)

        ctype = upstream.headers.get("content-type", "")

        # If JSON, buffer minimally and pass through bytes unchanged
        if ctype.startswith("application/json"):
            json_start = time.time()
            data = await upstream.aread()
            await upstream.aclose()
            json_time = time.time() - json_start
            total_time = time.time() - start_time
            logger.info(
                f"Returning JSON response: {len(data)} bytes "
                f"(json: {json_time:.2f}s, total: {total_time:.2f}s)"
            )
            return Response(
                content=data,
                status_code=status,
                headers=resp_headers,
                media_type=ctype,
            )

        # Otherwise stream raw chunks; the upstream response is closed by the
        # BackgroundTask once the client has consumed the stream.
        async def _aiter():
            chunk_count = 0
            total_bytes = 0
            async for chunk in upstream.aiter_raw():
                if chunk:
                    chunk_count += 1
                    total_bytes += len(chunk)
                    yield chunk
                # be nice to the event loop
                await asyncio.sleep(0)
            logger.info(f"Streamed {chunk_count} chunks, {total_bytes} bytes")

        return StreamingResponse(
            _aiter(),
            status_code=status,
            headers=resp_headers,
            media_type=ctype or None,
            background=BackgroundTask(upstream.aclose),
        )
    except httpx.ConnectTimeout:
        total_time = time.time() - start_time
        logger.error(f"Connect timeout after {total_time:.2f}s")
        raise HTTPException(status_code=504, detail="Chroma upstream connect timeout")
    except httpx.ReadTimeout:
        total_time = time.time() - start_time
        logger.error(f"Read timeout after {total_time:.2f}s")
        raise HTTPException(status_code=504, detail="Chroma upstream read timeout")
    except httpx.ConnectError as e:
        total_time = time.time() - start_time
        logger.error(f"Connect error after {total_time:.2f}s: {e}")
        raise HTTPException(
            status_code=502, detail=f"Chroma upstream connect error: {e}"
        )
    except httpx.TransportError as e:
        total_time = time.time() - start_time
        logger.error(f"Transport error after {total_time:.2f}s: {e}")
        raise HTTPException(
            status_code=502, detail=f"Chroma upstream transport error: {e}"
        )
    except Exception as e:
        total_time = time.time() - start_time
        logger.error(f"Unexpected error after {total_time:.2f}s: {e}")
        raise HTTPException(status_code=500, detail=f"Internal proxy error: {e}")


if __name__ == "__main__":
    print("Starting ChromaDB Auth Proxy")
    print(f"  Proxy URL:    http://0.0.0.0:{PROXY_PORT}")
    print(f"  ChromaDB URL: http://{CHROMA_HOST}:{CHROMA_PORT}")
    print(
        f"  Timeouts: connect={int(TIMEOUT_CONNECT)}s, "
        f"read={int(TIMEOUT_READ)}s, write={int(TIMEOUT_WRITE)}s"
    )
    print("  Logging: INFO level")
    logger.info("ChromaDB Auth Proxy starting up")
    uvicorn.run(app, host="0.0.0.0", port=PROXY_PORT)
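

# ---------------------------------------------------------------------------
# Example usage (illustrative; assumes the default PROXY_PORT and test token
# above plus a running ChromaDB instance; /api/v2/heartbeat is just an example
# path -- any ChromaDB endpoint can be used):
#
#   # Wrong or missing token -> rejected at the proxy
#   curl -i -H "Authorization: Bearer wrong" http://localhost:7860/api/v2/heartbeat
#
#   # Valid token -> forwarded verbatim to ChromaDB
#   curl -i -H "Authorization: Bearer test_token_123" \
#        http://localhost:7860/api/v2/heartbeat
#
# The "/" and "/health" endpoints above respond without a token.
# ---------------------------------------------------------------------------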