daetheris's picture
Initial commit after cleanup
80e0598 verified
raw
history blame
4.92 kB
import time
from collections import defaultdict
from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import logging
import asyncio # Concurrency limiting
logger = logging.getLogger(__name__)
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-IP sliding-window rate limiter for selected route prefixes.

    Keeps an in-memory list of request timestamps per client IP and
    answers HTTP 429 once the window is full.  State is per-process
    only — it is not shared across workers or hosts.
    """

    # Tuple (immutable) default avoids the shared-mutable-default pitfall
    # of the original list-literal keyword argument.
    DEFAULT_PROTECTED_ROUTES = (
        "/generate",
        "/api/generate",
        "/api/generate-with-report",
    )

    def __init__(
        self,
        app,
        rate_limit_per_minute=10,
        rate_limit_window=60,
        protected_routes=None,
    ):
        """
        Args:
            app: Downstream ASGI application.
            rate_limit_per_minute: Max requests allowed per window, per IP.
            rate_limit_window: Window length in seconds.
            protected_routes: Route prefixes to rate-limit; defaults to the
                generation endpoints (see DEFAULT_PROTECTED_ROUTES).
        """
        super().__init__(app)
        self.rate_limit_per_minute = rate_limit_per_minute
        self.rate_limit_window = rate_limit_window
        self.protected_routes = (
            list(protected_routes)
            if protected_routes is not None
            else list(self.DEFAULT_PROTECTED_ROUTES)
        )
        # NOTE(review): this dict grows with the number of distinct client
        # IPs ever seen and is never pruned of idle IPs — acceptable for a
        # small deployment, but a bounded/TTL cache would be safer long-term.
        self.ip_requests = defaultdict(list)
        logger.info(
            "Rate limit middleware initialized: %s requests per %ss",
            rate_limit_per_minute,
            rate_limit_window,
        )

    async def dispatch(self, request: Request, call_next):
        # request.client can be None (some test clients / unusual
        # transports); bucket those together instead of raising.
        client_ip = request.client.host if request.client else "unknown"
        current_time = time.time()

        # Only apply rate limiting to protected routes.
        if any(request.url.path.startswith(route) for route in self.protected_routes):
            # Drop timestamps that have aged out of the window.
            recent = [
                t
                for t in self.ip_requests[client_ip]
                if current_time - t < self.rate_limit_window
            ]

            if len(recent) >= self.rate_limit_per_minute:
                self.ip_requests[client_ip] = recent
                logger.warning(
                    "Rate limit exceeded for IP %s on %s",
                    client_ip,
                    request.url.path,
                )
                return JSONResponse(
                    status_code=429,
                    content={"detail": "Rate limit exceeded. Please try again later."},
                )

            # Record this request before passing it downstream.
            recent.append(current_time)
            self.ip_requests[client_ip] = recent

        return await call_next(request)
class ConcurrencyLimitMiddleware(BaseHTTPMiddleware):
    """Caps how many requests to heavy routes are processed concurrently.

    A request to a protected route must win a semaphore slot within
    ``timeout`` seconds; otherwise the client receives HTTP 503.
    Non-protected routes pass straight through.
    """

    def __init__(
        self,
        app,
        max_concurrent_requests=5,
        timeout=5.0,
        protected_routes=None,
    ):
        """
        Args:
            app: Downstream ASGI application.
            max_concurrent_requests: Semaphore size — maximum in-flight
                requests on protected routes.
            timeout: Seconds to wait for a slot before answering 503.
            protected_routes: Route prefixes to guard; defaults to the
                generation endpoints.
        """
        super().__init__(app)
        self.semaphore = asyncio.Semaphore(max_concurrent_requests)
        self.timeout = timeout
        self.protected_routes = protected_routes or [
            "/generate",
            "/api/generate",
            "/api/generate-with-report",
        ]
        logger.info(
            "Concurrency limit middleware initialized: %s concurrent requests",
            max_concurrent_requests,
        )

    async def dispatch(self, request, call_next):
        try:
            # Guard clause: non-protected routes bypass the semaphore.
            if not any(
                request.url.path.startswith(route) for route in self.protected_routes
            ):
                return await call_next(request)

            acquired = False
            try:
                # wait_for (rather than the asyncio.timeout() context
                # manager) keeps compatibility with older Python versions.
                # NOTE(review): there is a narrow race where wait_for
                # cancels acquire() just as it succeeds; recent CPython
                # Semaphore implementations hand the slot back on
                # cancellation, but this is worth confirming for the
                # targeted Python version.
                await asyncio.wait_for(self.semaphore.acquire(), timeout=self.timeout)
                acquired = True
                return await call_next(request)
            except asyncio.TimeoutError:
                # Could not get a slot in time — shed load with 503.
                logger.warning(f"Concurrency limit reached for {request.url.path}")
                return JSONResponse(
                    status_code=503,
                    content={"detail": "Server is at capacity. Please try again later."},
                )
            finally:
                # Only release if we actually acquired; the timeout path
                # must not release a slot it never held.
                if acquired:
                    self.semaphore.release()
        except Exception:
            # Log the full traceback server-side, but do NOT echo internal
            # error details back to the client (the original leaked str(e)
            # in the response body).
            logger.exception("Error in ConcurrencyLimitMiddleware")
            return JSONResponse(
                status_code=500,
                content={"detail": "Internal server error in middleware"},
            )
# Protection against large request payloads
class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
    """Rejects requests whose declared Content-Length exceeds a byte cap.

    Only the *declared* header is inspected — a client that omits or
    understates Content-Length (e.g. chunked transfer) is not caught here;
    the body itself is never read by this middleware.
    """

    def __init__(self, app, max_content_length=1024 * 1024):  # 1 MB default
        """
        Args:
            app: Downstream ASGI application.
            max_content_length: Maximum allowed Content-Length in bytes.
        """
        super().__init__(app)
        self.max_content_length = max_content_length
        logger.info(
            "Request size limit middleware initialized: %s bytes",
            max_content_length,
        )

    async def dispatch(self, request: Request, call_next):
        content_length = request.headers.get('content-length')
        if content_length:
            try:
                declared = int(content_length)
            except ValueError:
                # Original code let int() raise here, turning a malformed
                # client header into an unhandled 500; answer 400 instead.
                logger.warning("Malformed Content-Length header: %r", content_length)
                return JSONResponse(
                    status_code=400,
                    content={"detail": "Invalid Content-Length header"},
                )
            if declared > self.max_content_length:
                logger.warning("Request too large: %s bytes", declared)
                return JSONResponse(
                    status_code=413,
                    content={"detail": "Request too large"},
                )
        return await call_next(request)