|
|
import time |
|
|
from collections import defaultdict |
|
|
from fastapi import Request |
|
|
from fastapi.responses import JSONResponse |
|
|
from starlette.middleware.base import BaseHTTPMiddleware |
|
|
import logging |
|
|
import asyncio |
|
|
|
|
|
# Module-level logger used by the middleware classes defined in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-client sliding-window rate limiter for selected route prefixes.

    A request whose path starts with one of ``protected_routes`` is counted
    against the client's IP address. Once ``rate_limit_per_minute`` requests
    from that IP have been accepted within the last ``rate_limit_window``
    seconds, further requests receive a 429 JSON response until older
    requests age out of the window. Requests to any other path are passed
    through without being counted.

    Matching is by string prefix, so ``"/generate"`` also matches a path
    such as ``"/generate-anything"``.

    State is kept in this instance's ``ip_requests`` mapping
    (IP -> list of accepted-request timestamps from ``time.time()``).
    """

    # Immutable defaults; copied into a fresh list per instance so that no
    # mutable object is shared between instances through a default argument.
    DEFAULT_PROTECTED_ROUTES = (
        "/generate",
        "/api/generate",
        "/api/generate-with-report",
    )

    def __init__(
        self,
        app,
        rate_limit_per_minute=10,
        rate_limit_window=60,
        protected_routes=None,
    ):
        """Configure the limiter.

        Args:
            app: The ASGI application being wrapped.
            rate_limit_per_minute: Maximum accepted requests per client
                within one window. Despite the name, the period is whatever
                ``rate_limit_window`` is set to.
            rate_limit_window: Window length in seconds.
            protected_routes: Path prefixes to rate-limit. ``None`` selects
                ``DEFAULT_PROTECTED_ROUTES``; an explicit empty list
                protects nothing.
        """
        super().__init__(app)
        self.rate_limit_per_minute = rate_limit_per_minute
        self.rate_limit_window = rate_limit_window
        if protected_routes is None:
            protected_routes = list(self.DEFAULT_PROTECTED_ROUTES)
        self.protected_routes = protected_routes
        self.ip_requests = defaultdict(list)
        # Time of the last sweep that removed idle clients from ip_requests.
        self._last_purge = time.time()
        logger.info(
            "Rate limit middleware initialized: %s requests per %ss",
            rate_limit_per_minute,
            rate_limit_window,
        )

    def _purge_idle_clients(self, now):
        """Drop clients with no request inside the window (at most once per window)."""
        if now - self._last_purge < self.rate_limit_window:
            return
        self._last_purge = now
        # Timestamps are appended in arrival order, so the last one is the newest.
        idle = [
            ip
            for ip, stamps in self.ip_requests.items()
            if not stamps or now - stamps[-1] >= self.rate_limit_window
        ]
        for ip in idle:
            del self.ip_requests[ip]

    async def dispatch(self, request: Request, call_next):
        """Reject the request with 429 if its client exceeded the limit.

        Returns:
            The downstream response, or a 429 ``JSONResponse`` when the
            client already has ``rate_limit_per_minute`` accepted requests
            inside the current window.
        """
        path = request.url.path
        if not path.startswith(tuple(self.protected_routes)):
            return await call_next(request)

        # request.client can be None (no peer address available); such
        # requests share a single "unknown" bucket instead of crashing.
        client = request.client
        client_ip = client.host if client is not None else "unknown"
        now = time.time()

        # Without this sweep, every IP ever seen would stay in memory forever.
        self._purge_idle_clients(now)

        # Keep only the timestamps that are still inside the window.
        recent = [
            stamp
            for stamp in self.ip_requests.get(client_ip, ())
            if now - stamp < self.rate_limit_window
        ]

        if len(recent) >= self.rate_limit_per_minute:
            self.ip_requests[client_ip] = recent
            logger.warning("Rate limit exceeded for IP %s on %s", client_ip, path)
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded. Please try again later."},
            )

        # Only accepted requests are recorded; rejected ones do not extend
        # the client's penalty.
        recent.append(now)
        self.ip_requests[client_ip] = recent
        return await call_next(request)
|
|
|
|
|
class ConcurrencyLimitMiddleware(BaseHTTPMiddleware):
    """Cap the number of in-flight requests on selected route prefixes.

    A request whose path starts with one of ``protected_routes`` must obtain
    a slot from a shared ``asyncio.Semaphore`` before it is forwarded. If no
    slot becomes free within ``timeout`` seconds the request is answered
    with a 503 JSON response. Other paths are forwarded without a slot.

    For every path, an exception raised while forwarding the request is
    logged and converted to a 500 JSON response.
    """

    def __init__(
        self,
        app,
        max_concurrent_requests=5,
        timeout=5.0,
        protected_routes=None,
    ):
        """Configure the limiter.

        Args:
            app: The ASGI application being wrapped.
            max_concurrent_requests: Number of semaphore slots.
            timeout: Seconds to wait for a free slot before answering 503.
            protected_routes: Path prefixes to limit. Any falsy value
                (``None`` or an empty list) selects the built-in defaults.
        """
        super().__init__(app)
        self.semaphore = asyncio.Semaphore(max_concurrent_requests)
        self.timeout = timeout
        self.protected_routes = protected_routes or [
            "/generate",
            "/api/generate",
            "/api/generate-with-report",
        ]
        logger.info(
            "Concurrency limit middleware initialized: %s concurrent requests",
            max_concurrent_requests,
        )

    async def _forward(self, request, call_next):
        """Call the downstream app, turning any exception into a 500 response."""
        try:
            return await call_next(request)
        except Exception as exc:
            # logger.exception keeps the traceback in the logs; the client
            # gets a generic message so internal details are not disclosed.
            logger.exception("Error in ConcurrencyLimitMiddleware: %s", exc)
            return JSONResponse(
                status_code=500,
                content={"detail": "Internal server error"},
            )

    async def dispatch(self, request, call_next):
        """Forward the request, holding a semaphore slot on protected paths.

        Returns:
            The downstream response; a 503 ``JSONResponse`` if no slot was
            obtained within ``timeout``; or a 500 ``JSONResponse`` if
            forwarding raised an exception.
        """
        path = request.url.path
        if not path.startswith(tuple(self.protected_routes)):
            return await self._forward(request, call_next)

        # Only the acquire is guarded by the timeout handler. A timeout
        # raised by the downstream application must not be reported as
        # "server at capacity".
        try:
            await asyncio.wait_for(self.semaphore.acquire(), timeout=self.timeout)
        except asyncio.TimeoutError:
            logger.warning("Concurrency limit reached for %s", path)
            return JSONResponse(
                status_code=503,
                content={"detail": "Server is at capacity. Please try again later."},
            )

        # The slot is held from here on; release it however forwarding ends
        # (including cancellation).
        try:
            return await self._forward(request, call_next)
        finally:
            self.semaphore.release()
|
|
|
|
|
|
|
|
|
|
|
class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
    """Reject requests whose declared body size exceeds a limit.

    Only the ``Content-Length`` request header is inspected; the body itself
    is never read or measured here, so a request that carries no
    ``Content-Length`` header is forwarded unchecked.
    """

    def __init__(self, app, max_content_length=1024 * 1024):
        """Configure the limit.

        Args:
            app: The ASGI application being wrapped.
            max_content_length: Largest accepted ``Content-Length`` value in
                bytes (default 1 MiB).
        """
        super().__init__(app)
        self.max_content_length = max_content_length
        logger.info(
            "Request size limit middleware initialized: %s bytes",
            max_content_length,
        )

    async def dispatch(self, request: Request, call_next):
        """Validate ``Content-Length`` and forward the request if acceptable.

        Returns:
            The downstream response; a 413 ``JSONResponse`` when the
            declared length exceeds ``max_content_length``; or a 400
            ``JSONResponse`` when the header is not a non-negative integer.
        """
        declared = request.headers.get("content-length")
        if declared:
            try:
                size = int(declared)
            except ValueError:
                size = -1  # Not a number: handled as invalid below.

            # A malformed or negative length is a client error, not a
            # reason to fail with an unhandled exception.
            if size < 0:
                logger.warning("Invalid Content-Length header: %r", declared)
                return JSONResponse(
                    status_code=400,
                    content={"detail": "Invalid Content-Length header"},
                )

            if size > self.max_content_length:
                logger.warning("Request too large: %s bytes", declared)
                return JSONResponse(
                    status_code=413,
                    content={"detail": "Request too large"},
                )

        return await call_next(request)
|
|
|