Spaces:
Sleeping
Sleeping
import logging
import uuid
from typing import Callable, Optional

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from ..database import db_manager
class SessionMiddleware(BaseHTTPMiddleware):
    """Middleware to handle session-based database management.

    For every non-OPTIONS request this middleware:

    1. Reads the ``x-session-id`` request header, or generates a new UUID4
       string when it is missing or empty, and stores it on
       ``request.state.session_id``.
    2. For paths under one of ``DATABASE_PATH_PREFIXES`` (excluding the exact
       paths in ``SKIP_VALIDATION_ENDPOINTS`` and anything under
       ``SKIP_VALIDATION_PREFIXES``), requires that ``db_manager`` already has a
       hotel selected for the session. Otherwise it tries to authenticate with
       the ``x-hotel-name`` / ``x-hotel-password`` headers and, on success,
       sets the hotel context for the session. On failure a JSON error
       response is returned and the route handler is not called.
    3. Echoes the session ID back in the ``x-session-id`` response header.

    Validation is skipped entirely when ``require_database`` is False.
    """

    # Path prefixes of routes that operate on the session's hotel database.
    DATABASE_PATH_PREFIXES = (
        '/settings/',
        '/customer/api/',
        '/chef/',
        '/admin/',
        '/analytics/',
        '/tables/',
        '/feedback/',
        '/loyalty/',
        '/selection-offers/',
    )

    # Exact paths that must be reachable before a hotel has been selected
    # (they list or switch the current hotel/database).
    SKIP_VALIDATION_ENDPOINTS = frozenset({
        '/settings/databases',
        '/settings/hotels',
        '/settings/switch-database',
        '/settings/switch-hotel',
        '/settings/current-database',
        '/settings/current-hotel',
    })

    # Admin and chef routes handle their own database selection.
    SKIP_VALIDATION_PREFIXES = ('/admin/', '/chef/')

    _logger = logging.getLogger(__name__)

    def __init__(self, app, require_database: bool = True) -> None:
        """Wrap ``app``; set ``require_database=False`` to disable hotel validation."""
        super().__init__(app)
        self.require_database = require_database

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Assign a session ID, enforce hotel context where required, then call the app."""
        # CORS preflight requests pass straight through without a session ID.
        if request.method == "OPTIONS":
            return await call_next(request)

        # Reuse the client's session ID when supplied (non-empty), else mint one.
        session_id = request.headers.get('x-session-id') or str(uuid.uuid4())
        request.state.session_id = session_id

        if self._should_validate(request.url.path):
            error_response = self._ensure_hotel_context(request, session_id)
            if error_response is not None:
                return error_response

        response = await call_next(request)
        response.headers["x-session-id"] = session_id
        return response

    def _should_validate(self, path: str) -> bool:
        """Return True if ``path`` requires a hotel context for the session."""
        if not self.require_database:
            return False
        if not path.startswith(self.DATABASE_PATH_PREFIXES):
            return False
        if path in self.SKIP_VALIDATION_ENDPOINTS:
            return False
        return not path.startswith(self.SKIP_VALIDATION_PREFIXES)

    def _ensure_hotel_context(
        self, request: Request, session_id: str
    ) -> Optional[JSONResponse]:
        """Make sure the session has a hotel context.

        Returns None when the session already has a hotel, or when header
        credentials authenticate successfully (the context is then set).
        Otherwise returns the error response to send back:

        - 400 when no hotel is selected and no credentials were supplied;
        - 401 when the credentials are rejected;
        - 500 when authentication raises.
        """
        # NOTE(review): db_manager is called synchronously from async code. If
        # these calls do blocking I/O they stall the event loop; consider
        # starlette.concurrency.run_in_threadpool once db_manager's
        # thread-safety is confirmed.
        if db_manager.get_current_hotel_id(session_id):
            return None

        # NOTE(review): the hotel password travels in a request header. Make
        # sure this is only served over HTTPS and that headers are not logged.
        hotel_name = request.headers.get('x-hotel-name')
        password = request.headers.get('x-hotel-password')
        if not (hotel_name and password):
            return JSONResponse(
                status_code=400,
                content={
                    "detail": "No hotel selected. Please select a hotel first.",
                    "error_code": "HOTEL_NOT_SELECTED"
                }
            )

        try:
            hotel_id = db_manager.authenticate_hotel(hotel_name, password)
            if hotel_id:
                db_manager.set_hotel_context(session_id, hotel_id)
        except Exception:
            # Log the details server-side; never echo internal exception text
            # (driver/DB messages, connection info) back to the client.
            self._logger.exception(
                "Hotel authentication failed for hotel %r", hotel_name
            )
            return JSONResponse(
                status_code=500,
                content={
                    "detail": "Hotel authentication failed",
                    "error_code": "HOTEL_VERIFICATION_ERROR"
                }
            )

        if not hotel_id:
            return JSONResponse(
                status_code=401,
                content={
                    "detail": "Invalid hotel credentials",
                    "error_code": "HOTEL_AUTH_FAILED"
                }
            )
        return None
def get_session_id(request: Request) -> str:
    """Return the session ID for ``request``.

    Normally this is the ID stored on ``request.state`` by
    ``SessionMiddleware``. If none is stored (e.g. for OPTIONS requests, which
    the middleware passes through untouched), a new UUID4 string is generated
    and cached on ``request.state``. Repeated calls within the same request
    therefore return the same value.
    """
    session_id = getattr(request.state, 'session_id', None)
    if session_id is None:
        # Generate lazily so the common path does not build a throwaway UUID.
        session_id = str(uuid.uuid4())
        request.state.session_id = session_id
    return session_id