NotebookLMClone / auth /session.py
github-actions[bot]
Sync from GitHub b32aec24fbce3d4b772838e2ed3310b56213b520
3fe4567
from __future__ import annotations
import os
from dataclasses import dataclass
from typing import Any
from fastapi import Depends, HTTPException, Request, status
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
from sqlalchemy.orm import Session
from starlette.middleware.sessions import SessionMiddleware
from data import crud
from data.db import get_db
AUTH_MODE_DEV = "dev"
AUTH_MODE_HF = "hf_oauth"
AUTH_BRIDGE_SALT = "streamlit-auth-bridge"
DEFAULT_DEV_SESSION_SECRET = "dev-only-session-secret-change-me"
@dataclass(frozen=True)
class CurrentUser:
id: int
email: str
display_name: str | None = None
avatar_url: str | None = None
class AuthBridgeTokenError(ValueError):
pass
def get_auth_mode() -> str:
configured = os.getenv("AUTH_MODE", AUTH_MODE_DEV).strip().lower()
return configured if configured in {AUTH_MODE_DEV, AUTH_MODE_HF} else AUTH_MODE_DEV
def configure_session_middleware(app) -> None:
"""Attach Starlette session middleware once during app setup."""
secret = os.getenv("APP_SESSION_SECRET", DEFAULT_DEV_SESSION_SECRET).strip()
auth_mode = get_auth_mode()
if auth_mode == AUTH_MODE_HF and (not secret or secret == DEFAULT_DEV_SESSION_SECRET):
raise RuntimeError("APP_SESSION_SECRET must be set to a non-default value in hf_oauth mode.")
same_site = os.getenv("SESSION_COOKIE_SAMESITE", "lax").strip().lower()
if same_site not in {"lax", "strict", "none"}:
same_site = "lax"
secure_default = "1" if auth_mode == AUTH_MODE_HF else "0"
https_only = os.getenv("SESSION_COOKIE_SECURE", secure_default).strip().lower() in {
"1",
"true",
"yes",
"on",
}
app.add_middleware(
SessionMiddleware,
secret_key=secret,
same_site=same_site,
https_only=https_only,
max_age=60 * 60 * 24 * 7, # 7 days
)
def _bridge_serializer() -> URLSafeTimedSerializer:
secret = os.getenv("APP_SESSION_SECRET", DEFAULT_DEV_SESSION_SECRET)
return URLSafeTimedSerializer(secret_key=secret, salt=AUTH_BRIDGE_SALT)
def _session_user_to_current_user(session_user: dict[str, Any]) -> CurrentUser | None:
try:
return CurrentUser(
id=int(session_user["id"]),
email=str(session_user["email"]),
display_name=(str(session_user["display_name"]) if session_user.get("display_name") else None),
avatar_url=(str(session_user["avatar_url"]) if session_user.get("avatar_url") else None),
)
except (KeyError, TypeError, ValueError):
return None
def get_session_user(request: Request) -> CurrentUser | None:
raw_user = request.session.get("user")
if not isinstance(raw_user, dict):
return None
return _session_user_to_current_user(raw_user)
def set_session_user(request: Request, user: CurrentUser) -> None:
request.session["user"] = {
"id": user.id,
"email": user.email,
"display_name": user.display_name,
"avatar_url": user.avatar_url,
}
def clear_session_user(request: Request) -> None:
request.session.pop("user", None)
def generate_auth_bridge_token(user: CurrentUser) -> str:
payload = {
"id": user.id,
"email": user.email,
"display_name": user.display_name,
"avatar_url": user.avatar_url,
}
return _bridge_serializer().dumps(payload)
def decode_auth_bridge_token(token: str) -> CurrentUser:
ttl_seconds = int(os.getenv("AUTH_BRIDGE_TOKEN_TTL_SECONDS", "300"))
try:
payload = _bridge_serializer().loads(token, max_age=ttl_seconds)
except SignatureExpired as exc:
raise AuthBridgeTokenError("Bridge token expired. Please sign in again.") from exc
except BadSignature as exc:
raise AuthBridgeTokenError("Invalid bridge token.") from exc
if not isinstance(payload, dict):
raise AuthBridgeTokenError("Invalid bridge token payload.")
user = _session_user_to_current_user(payload)
if user is None:
raise AuthBridgeTokenError("Invalid bridge token payload.")
return user
def _ensure_dev_user(request: Request, db: Session) -> CurrentUser:
dev_user_id = int(os.getenv("AUTH_DEV_USER_ID", "1"))
dev_email = os.getenv("AUTH_DEV_EMAIL", "dev@example.com")
dev_name = os.getenv("AUTH_DEV_DISPLAY_NAME", "Dev User")
user = crud.get_or_create_user(
db=db,
user_id=dev_user_id,
email=dev_email,
display_name=dev_name,
)
current = CurrentUser(
id=user.id,
email=user.email,
display_name=user.display_name,
avatar_url=user.avatar_url,
)
set_session_user(request, current)
return current
def require_current_user(request: Request, db: Session = Depends(get_db)) -> CurrentUser:
"""Resolve current user from session; auto-provision in dev mode."""
session_user = get_session_user(request)
if session_user:
return session_user
if get_auth_mode() == AUTH_MODE_DEV:
return _ensure_dev_user(request, db)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required.",
)