Spaces:
Paused
Paused
| """CSRF Protection for web interface""" | |
import hashlib
import hmac
import secrets

from fastapi import HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
class CSRFProtection:
    """CSRF token generation and validation.

    Tokens have the form ``"<nonce>.<signature>"``. ``nonce`` is 32 random
    hex characters. ``signature`` is the hex HMAC-SHA256 of
    ``"<session_id>:<nonce>"``, keyed with the server secret. Tokens carry
    no expiry: a token stays valid as long as the secret key and the
    session ID it was issued for are unchanged.
    """

    def __init__(self, secret_key: str):
        """Store *secret_key*, UTF-8 encoded, as the HMAC key."""
        self.secret_key = secret_key.encode()

    def _sign(self, session_id: str, nonce: str) -> str:
        """Return the hex HMAC-SHA256 of ``"<session_id>:<nonce>"``."""
        message = f"{session_id}:{nonce}".encode()
        return hmac.new(self.secret_key, message, hashlib.sha256).hexdigest()

    def generate_token(self, session_id: str) -> str:
        """Generate a CSRF token bound to *session_id*.

        Each call uses a fresh random nonce. Repeated calls for the same
        session therefore return different tokens, all of which validate.
        """
        nonce = secrets.token_hex(16)
        return f"{nonce}.{self._sign(session_id, nonce)}"

    def validate_token(self, token: str, session_id: str) -> bool:
        """Return True if *token* was produced by ``generate_token`` for *session_id*.

        Malformed input never raises; it yields False. Malformed means:
        - not exactly one ``.`` separator;
        - a non-``str`` token such as ``None`` or ``bytes``;
        - non-ASCII characters in the signature.
        """
        try:
            nonce, signature = token.split(".")
            expected_signature = self._sign(session_id, nonce)
            # Constant-time comparison. compare_digest raises TypeError for
            # str arguments containing non-ASCII characters, and such
            # characters can arrive directly from a client-supplied header.
            return hmac.compare_digest(signature, expected_signature)
        except (ValueError, AttributeError, TypeError):
            return False
class CSRFMiddleware(BaseHTTPMiddleware):
    """Middleware to enforce CSRF protection.

    Requests are passed through unchanged if they use an exempt method
    (GET, HEAD, OPTIONS, TRACE) or target an exempt path or a sub-path of
    one. Every other request must carry an ``X-CSRF-Token`` header that
    validates against the ``session_id`` cookie. Otherwise a 403 JSON
    response ``{"detail": ...}`` is returned.
    """

    def __init__(self, app, secret_key: str):
        super().__init__(app)
        self.csrf = CSRFProtection(secret_key)
        self.exempt_methods = {"GET", "HEAD", "OPTIONS", "TRACE"}
        self.exempt_paths = {"/auth/token", "/auth/register", "/health", "/metrics"}

    def _is_exempt_path(self, path: str) -> bool:
        """Return True if *path* equals an exempt path or lies beneath one.

        The ``/`` boundary matters. A bare prefix test would also exempt
        unrelated routes such as ``/healthz`` or ``/metrics-admin``.
        """
        return any(
            path == exempt or path.startswith(exempt.rstrip("/") + "/")
            for exempt in self.exempt_paths
        )

    @staticmethod
    def _forbidden(detail: str):
        """Build the 403 response sent when CSRF validation fails."""
        # An HTTPException raised inside BaseHTTPMiddleware.dispatch is not
        # seen by FastAPI's exception handlers, which run further inside the
        # stack. It would surface as a 500, so return the response directly.
        return JSONResponse(status_code=403, content={"detail": detail})

    async def dispatch(self, request: Request, call_next):
        # Skip CSRF check for exempt methods and paths
        if request.method in self.exempt_methods:
            return await call_next(request)
        if self._is_exempt_path(request.url.path):
            return await call_next(request)

        # Get CSRF token from header
        csrf_token = request.headers.get("X-CSRF-Token")
        if not csrf_token:
            return self._forbidden("CSRF token missing")

        # The session ID comes from the cookie only. Without a cookie the
        # token is checked against "", so a token issued for an empty
        # session ID is accepted on any cookie-less request.
        session_id = request.cookies.get("session_id", "")
        if not self.csrf.validate_token(csrf_token, session_id):
            return self._forbidden("Invalid CSRF token")

        return await call_next(request)
def get_csrf_token(secret_key: str, session_id: str) -> str:
    """Return a freshly signed CSRF token bound to *session_id*.

    Convenience wrapper for callers that hold only the raw secret key. It
    builds a throwaway ``CSRFProtection`` and delegates to its
    ``generate_token`` method.
    """
    protector = CSRFProtection(secret_key)
    token = protector.generate_token(session_id)
    return token