|
|
""" |
|
|
管理后台认证中间件 |
|
|
""" |
|
|
from fastapi import Request, HTTPException, status |
|
|
from fastapi.responses import RedirectResponse |
|
|
from typing import Optional |
|
|
import hashlib |
|
|
import secrets |
|
|
from datetime import datetime, timedelta |
|
|
|
|
|
from app.core.config import settings |
|
|
|
|
|
|
|
|
# Lifetime, in hours, given to every newly created admin session.
SESSION_EXPIRE_HOURS = 24

# In-process session store mapping a session token to a record with the keys
# "created_at", "expires_at" and "authenticated".
# NOTE(review): this is a plain module-level dict, so sessions do not survive a
# restart and are not shared between multiple worker processes.
_sessions = dict()
|
|
|
|
|
|
|
|
def generate_session_token() -> str:
    """Return a fresh, URL-safe random token for a new admin session.

    The token encodes 32 bytes drawn from the ``secrets`` CSPRNG.
    """
    entropy_bytes = 32
    token = secrets.token_urlsafe(entropy_bytes)
    return token
|
|
|
|
|
|
|
|
def create_session(password: str) -> Optional[str]:
    """Check the admin password and, on success, open a new session.

    Args:
        password: The password submitted by the user.

    Returns:
        A new session token when the password matches the configured admin
        password, otherwise ``None`` (wrong password, empty password, or no
        admin password configured).
    """
    expected = settings.ADMIN_PASSWORD

    # Refuse to authenticate when no usable admin password is configured or
    # the submitted one is empty -- otherwise "" would match an unset "".
    if not isinstance(password, str) or not isinstance(expected, str):
        return None
    if not password or not expected:
        return None

    # Constant-time comparison so response timing does not reveal how much of
    # the password matched. Compare bytes: compare_digest rejects non-ASCII str.
    if not secrets.compare_digest(
        password.encode("utf-8"), expected.encode("utf-8")
    ):
        return None

    session_token = generate_session_token()

    # Take "now" once so created_at and expires_at share the same base instant.
    now = datetime.now()
    _sessions[session_token] = {
        "created_at": now,
        "expires_at": now + timedelta(hours=SESSION_EXPIRE_HOURS),
        "authenticated": True,
    }

    return session_token
|
|
|
|
|
|
|
|
def verify_session(session_token: Optional[str]) -> bool:
    """Report whether *session_token* refers to a live, authenticated session.

    An expired session is evicted from the in-memory store as a side effect.

    Args:
        session_token: Token supplied by the client; may be ``None`` or empty.

    Returns:
        ``True`` if the session exists, has not expired and is flagged as
        authenticated; ``False`` otherwise.
    """
    if not session_token:
        return False

    session = _sessions.get(session_token)
    if not session:
        return False

    if datetime.now() > session["expires_at"]:
        # pop() instead of del: another caller (a concurrent request, logout,
        # or the periodic cleanup) may already have removed this entry, and
        # that must not surface as a KeyError.
        _sessions.pop(session_token, None)
        return False

    return bool(session.get("authenticated", False))
|
|
|
|
|
|
|
|
def delete_session(session_token: Optional[str]):
    """Log out by discarding the session identified by *session_token*.

    A falsy token, or one that is not present in the store, is ignored.
    """
    if not session_token:
        return
    _sessions.pop(session_token, None)
|
|
|
|
|
|
|
|
def get_session_token_from_request(request: Request) -> Optional[str]:
    """Return the admin session token carried by *request*, if any.

    The token is read from the ``admin_session`` cookie; ``None`` is returned
    when that cookie is absent.
    """
    cookie_jar = request.cookies
    token = cookie_jar.get("admin_session")
    return token
|
|
|
|
|
|
|
|
async def require_auth(request: Request):
    """Dependency that only lets requests with a valid admin session through.

    Requests without a valid session are rejected with an HTTP 303 whose
    ``Location`` header points at ``/admin/login``.

    Usage in a route::

        @router.get("/admin", dependencies=[Depends(require_auth)])
    """
    if verify_session(get_session_token_from_request(request)):
        return

    redirect_headers = {"Location": "/admin/login"}
    raise HTTPException(
        status_code=status.HTTP_303_SEE_OTHER,
        detail="未登录",
        headers=redirect_headers,
    )
|
|
|
|
|
|
|
|
def get_authenticated_user(request: Request) -> bool:
    """Tell whether *request* belongs to a logged-in admin (used by templates).

    Returns:
        ``True`` when the session token carried by the request is valid,
        ``False`` otherwise.
    """
    return verify_session(get_session_token_from_request(request))
|
|
|
|
|
|
|
|
def cleanup_expired_sessions() -> int:
    """Evict every expired session from the in-memory store.

    Meant to be invoked periodically (e.g. from a scheduled job).

    Returns:
        The number of sessions actually removed by this call.
    """
    now = datetime.now()

    # Iterate over a snapshot: if the scheduler runs this off the request
    # thread, a concurrent login/logout mutating _sessions mid-iteration would
    # otherwise raise "dictionary changed size during iteration".
    expired_tokens = [
        token
        for token, session in list(_sessions.items())
        if now > session["expires_at"]
    ]

    removed = 0
    for token in expired_tokens:
        # The entry may already have been evicted by verify_session or
        # delete_session in the meantime; tolerate that instead of KeyError.
        if _sessions.pop(token, None) is not None:
            removed += 1

    return removed
|
|
|