# Source: backend/app/main.py
# Commit 5873878 (bastienp): feat(security): configure CORS and API validation
import os
import secrets
import time
from datetime import datetime, timezone
from pathlib import Path

from dotenv import load_dotenv
from fastapi import FastAPI, Request, HTTPException, Security, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security.api_key import APIKeyHeader

from app.core.logging import get_logger, setup_logging
from app.routes import health, wagons, chat, players, generate
# Read a local .env file (if any) before any os.environ lookup below.
load_dotenv()

# Logger shared by everything in this module.
logger = get_logger("main")

# Values taken from the environment. API_KEY is None when unset;
# FRONTEND_URL falls back to the local development frontend.
API_KEY = os.environ.get("API_KEY")
FRONTEND_URL = os.environ.get("FRONTEND_URL", "http://localhost:3000")

# Security scheme that reads the client's key from the "X-API-Key" header.
api_key_header = APIKeyHeader(auto_error=True, name="X-API-Key")
async def get_api_key(
    api_key_header: str = Security(api_key_header),
    request: Request = None,
):
    """Validate the ``X-API-Key`` header against the server's ``API_KEY``.

    Args:
        api_key_header: Key supplied by the client, resolved through the
            module-level ``APIKeyHeader`` security scheme.
        request: Incoming request, used only to log the client address when
            the key is rejected. Optional so existing direct callers that
            pass only the key keep working.

    Returns:
        The supplied key when it matches ``API_KEY``.

    Raises:
        HTTPException: 500 when ``API_KEY`` is not configured, 401 when the
            supplied key is empty, 403 when the key does not match.
    """
    # NOTE: the annotation must stay a plain ``Request`` (not Optional[...])
    # so that FastAPI recognises it and injects the current request.
    if not API_KEY:
        logger.error("API key not configured on server")
        raise HTTPException(
            status_code=500,
            detail="Server configuration error",  # Don't expose specific details
        )

    if not api_key_header:
        raise HTTPException(
            status_code=401,
            detail="Missing API key",
            headers={"WWW-Authenticate": "ApiKey"},
        )

    # compare_digest() raises TypeError for str arguments containing
    # non-ASCII characters, which would turn a malformed header into a 500.
    # Comparing the encoded bytes keeps the check constant-time and total.
    supplied_key = api_key_header.encode("utf-8")
    expected_key = API_KEY.encode("utf-8")
    if not secrets.compare_digest(supplied_key, expected_key):
        # The original code referenced an undefined ``request`` here, so the
        # rejection path crashed with NameError instead of returning 403.
        if request is not None and request.client:
            client_host = request.client.host
        else:
            client_host = "unknown"
        logger.warning("Invalid API key attempt from %s", client_host)
        raise HTTPException(
            status_code=403,
            detail="Invalid authentication credentials",
            headers={"WWW-Authenticate": "ApiKey"},
        )

    return api_key_header
# Application object; get_api_key is declared as an application-wide dependency.
app = FastAPI(
    dependencies=[Depends(get_api_key)],
    version="1.0.0",
    title="Game Jam API",
    description="API for Game Jam Hackathon",
)
# CORS: a single allowed origin (FRONTEND_URL) instead of "*", which is why
# credentials can be enabled here.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=[FRONTEND_URL],
    allow_headers=["*"],
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"],
    max_age=3600,
    expose_headers=[
        "Location",
        "Access-Control-Allow-Origin",
        "Access-Control-Allow-Methods",
        "Access-Control-Allow-Headers",
        "Access-Control-Allow-Credentials",
        "Access-Control-Expose-Headers",
    ],
)
@app.middleware("http")
async def handle_redirects(request: Request, call_next):
    """Ensure CORS headers are in redirect responses and force https in the 'Location' header.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that forwards the request down the middleware
            stack and returns the response.

    Returns:
        The downstream response, with fallback CORS headers filled in and,
        for 301/302/307/308 responses, an ``http://`` ``Location`` rewritten
        to ``https://``.
    """
    response = await call_next(request)

    # Only fill in headers that are missing. Unconditionally assigning them
    # clobbered values already computed downstream (e.g. by CORSMiddleware),
    # such as the echoed Access-Control-Allow-Headers list. A literal "*" is
    # not treated as a wildcard by browsers for credentialed requests.
    fallback_cors_headers = {
        "Access-Control-Allow-Origin": FRONTEND_URL,
        "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS, HEAD, PATCH",
        "Access-Control-Allow-Headers": "*",
        "Access-Control-Max-Age": "3600",
    }
    for header_name, header_value in fallback_cors_headers.items():
        response.headers.setdefault(header_name, header_value)

    if response.status_code in (301, 302, 307, 308):
        # Let the frontend read the redirect target unless an expose list
        # was already provided downstream.
        response.headers.setdefault("Access-Control-Expose-Headers", "Location")
        location = response.headers.get("Location")
        if location is not None and location.startswith("http://"):
            # Upgrade only the scheme prefix; the rest of the URL is untouched.
            response.headers["Location"] = location.replace("http://", "https://", 1)

    return response
@app.middleware("http")
async def security_headers(request: Request, call_next):
    """Attach a fixed set of security and no-cache headers to every response."""
    response = await call_next(request)

    # Content-Security-Policy directives, concatenated in order.
    content_security_policy = "".join(
        (
            "default-src 'self'; ",
            "script-src 'self' 'unsafe-inline' 'unsafe-eval'; ",
            "style-src 'self' 'unsafe-inline'; ",
            "img-src 'self' data: https:; ",
            "connect-src 'self' ",
        )
    )

    # Each listed browser feature gets an empty allow-list.
    blocked_features = (
        "accelerometer",
        "camera",
        "geolocation",
        "gyroscope",
        "magnetometer",
        "microphone",
        "payment",
        "usb",
    )
    permissions_policy = ", ".join(f"{feature}=()" for feature in blocked_features)

    header_values = (
        ("X-Content-Type-Options", "nosniff"),
        ("X-Frame-Options", "DENY"),  # Stricter than SAMEORIGIN
        ("X-XSS-Protection", "1; mode=block"),
        ("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload"),
        ("Content-Security-Policy", content_security_policy),
        ("Permissions-Policy", permissions_policy),
        ("Cache-Control", "no-store, no-cache, must-revalidate, proxy-revalidate"),
        ("Pragma", "no-cache"),
        ("Expires", "0"),
    )
    for header_name, header_value in header_values:
        response.headers[header_name] = header_value

    return response
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Middleware to log all requests and responses.

    Logs an "Incoming request" record before the request is handled, then
    either a "Request completed" record with the status code and elapsed
    time, or a "Request failed" record (with traceback) if handling raised.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that forwards the request down the middleware
            stack and returns the response.

    Returns:
        The downstream response, unmodified.

    Raises:
        Exception: Any exception raised downstream is logged and re-raised.
    """
    # perf_counter() is monotonic, so the measured duration cannot be
    # skewed by wall-clock adjustments the way time.time() can.
    started_at = time.perf_counter()
    method = request.method
    url = str(request.url)

    logger.info(
        "Incoming request",
        extra={
            "method": method,
            "url": url,
            "client_host": request.client.host if request.client else None,
            # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated
            # and yields a naive value with no offset information.
            "timestamp": datetime.now(timezone.utc).isoformat(),
        },
    )

    # Keep the try body minimal so only downstream failures are reported
    # as a failed request.
    try:
        response = await call_next(request)
    except Exception as e:
        logger.error(
            "Request failed",
            extra={"method": method, "url": url, "error": str(e)},
            exc_info=True,  # keep the traceback; str(e) alone is often empty
        )
        raise

    elapsed_seconds = time.perf_counter() - started_at
    logger.info(
        "Request completed",
        extra={
            "method": method,
            "url": url,
            "status_code": response.status_code,
            "process_time_ms": round(elapsed_seconds * 1000, 2),
        },
    )
    return response
# Register each route module's router on the application, in this order.
for _route_module in (health, wagons, chat, players, generate):
    app.include_router(_route_module.router)
@app.get("/")
async def root():
    """Log the access and return a welcome message with a map of endpoint paths."""
    logger.info("Root endpoint accessed")
    welcome_payload = dict(
        message="Welcome to Game Jam API",
        docs_url="/docs",
        health_check="/health",
        wagons_endpoint="/api/wagons",
        chat_endpoint="/api/chat",
        players_endpoint="/api/players",
    )
    return welcome_payload
@app.on_event("startup")
async def startup_event():
    """Create the ``logs`` directory if it is missing, then configure logging."""
    Path("logs").mkdir(exist_ok=True)
    setup_logging()