Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import os | |
| from dataclasses import dataclass | |
| from typing import Any | |
| from fastapi import Depends, HTTPException, Request, status | |
| from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer | |
| from sqlalchemy.orm import Session | |
| from starlette.middleware.sessions import SessionMiddleware | |
| from data import crud | |
| from data.db import get_db | |
| AUTH_MODE_DEV = "dev" | |
| AUTH_MODE_HF = "hf_oauth" | |
| AUTH_BRIDGE_SALT = "streamlit-auth-bridge" | |
| DEFAULT_DEV_SESSION_SECRET = "dev-only-session-secret-change-me" | |
| class CurrentUser: | |
| id: int | |
| email: str | |
| display_name: str | None = None | |
| avatar_url: str | None = None | |
| class AuthBridgeTokenError(ValueError): | |
| pass | |
| def get_auth_mode() -> str: | |
| configured = os.getenv("AUTH_MODE", AUTH_MODE_DEV).strip().lower() | |
| return configured if configured in {AUTH_MODE_DEV, AUTH_MODE_HF} else AUTH_MODE_DEV | |
| def configure_session_middleware(app) -> None: | |
| """Attach Starlette session middleware once during app setup.""" | |
| secret = os.getenv("APP_SESSION_SECRET", DEFAULT_DEV_SESSION_SECRET).strip() | |
| auth_mode = get_auth_mode() | |
| if auth_mode == AUTH_MODE_HF and (not secret or secret == DEFAULT_DEV_SESSION_SECRET): | |
| raise RuntimeError("APP_SESSION_SECRET must be set to a non-default value in hf_oauth mode.") | |
| same_site = os.getenv("SESSION_COOKIE_SAMESITE", "lax").strip().lower() | |
| if same_site not in {"lax", "strict", "none"}: | |
| same_site = "lax" | |
| secure_default = "1" if auth_mode == AUTH_MODE_HF else "0" | |
| https_only = os.getenv("SESSION_COOKIE_SECURE", secure_default).strip().lower() in { | |
| "1", | |
| "true", | |
| "yes", | |
| "on", | |
| } | |
| app.add_middleware( | |
| SessionMiddleware, | |
| secret_key=secret, | |
| same_site=same_site, | |
| https_only=https_only, | |
| max_age=60 * 60 * 24 * 7, # 7 days | |
| ) | |
| def _bridge_serializer() -> URLSafeTimedSerializer: | |
| secret = os.getenv("APP_SESSION_SECRET", DEFAULT_DEV_SESSION_SECRET) | |
| return URLSafeTimedSerializer(secret_key=secret, salt=AUTH_BRIDGE_SALT) | |
| def _session_user_to_current_user(session_user: dict[str, Any]) -> CurrentUser | None: | |
| try: | |
| return CurrentUser( | |
| id=int(session_user["id"]), | |
| email=str(session_user["email"]), | |
| display_name=(str(session_user["display_name"]) if session_user.get("display_name") else None), | |
| avatar_url=(str(session_user["avatar_url"]) if session_user.get("avatar_url") else None), | |
| ) | |
| except (KeyError, TypeError, ValueError): | |
| return None | |
| def get_session_user(request: Request) -> CurrentUser | None: | |
| raw_user = request.session.get("user") | |
| if not isinstance(raw_user, dict): | |
| return None | |
| return _session_user_to_current_user(raw_user) | |
| def set_session_user(request: Request, user: CurrentUser) -> None: | |
| request.session["user"] = { | |
| "id": user.id, | |
| "email": user.email, | |
| "display_name": user.display_name, | |
| "avatar_url": user.avatar_url, | |
| } | |
| def clear_session_user(request: Request) -> None: | |
| request.session.pop("user", None) | |
| def generate_auth_bridge_token(user: CurrentUser) -> str: | |
| payload = { | |
| "id": user.id, | |
| "email": user.email, | |
| "display_name": user.display_name, | |
| "avatar_url": user.avatar_url, | |
| } | |
| return _bridge_serializer().dumps(payload) | |
| def decode_auth_bridge_token(token: str) -> CurrentUser: | |
| ttl_seconds = int(os.getenv("AUTH_BRIDGE_TOKEN_TTL_SECONDS", "300")) | |
| try: | |
| payload = _bridge_serializer().loads(token, max_age=ttl_seconds) | |
| except SignatureExpired as exc: | |
| raise AuthBridgeTokenError("Bridge token expired. Please sign in again.") from exc | |
| except BadSignature as exc: | |
| raise AuthBridgeTokenError("Invalid bridge token.") from exc | |
| if not isinstance(payload, dict): | |
| raise AuthBridgeTokenError("Invalid bridge token payload.") | |
| user = _session_user_to_current_user(payload) | |
| if user is None: | |
| raise AuthBridgeTokenError("Invalid bridge token payload.") | |
| return user | |
| def _ensure_dev_user(request: Request, db: Session) -> CurrentUser: | |
| dev_user_id = int(os.getenv("AUTH_DEV_USER_ID", "1")) | |
| dev_email = os.getenv("AUTH_DEV_EMAIL", "dev@example.com") | |
| dev_name = os.getenv("AUTH_DEV_DISPLAY_NAME", "Dev User") | |
| user = crud.get_or_create_user( | |
| db=db, | |
| user_id=dev_user_id, | |
| email=dev_email, | |
| display_name=dev_name, | |
| ) | |
| current = CurrentUser( | |
| id=user.id, | |
| email=user.email, | |
| display_name=user.display_name, | |
| avatar_url=user.avatar_url, | |
| ) | |
| set_session_user(request, current) | |
| return current | |
| def require_current_user(request: Request, db: Session = Depends(get_db)) -> CurrentUser: | |
| """Resolve current user from session; auto-provision in dev mode.""" | |
| session_user = get_session_user(request) | |
| if session_user: | |
| return session_user | |
| if get_auth_mode() == AUTH_MODE_DEV: | |
| return _ensure_dev_user(request, db) | |
| raise HTTPException( | |
| status_code=status.HTTP_401_UNAUTHORIZED, | |
| detail="Authentication required.", | |
| ) | |