|
|
|
|
|
import hmac
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any

from fastapi import HTTPException, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel

from config import Config
|
|
|
|
|
|
|
|
# Shared passlib context used by the password helpers in this module:
# bcrypt is the only configured scheme.
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)
|
|
|
|
|
|
|
|
# HTTP Bearer scheme instance used as a FastAPI dependency to read the
# credentials passed to get_current_user.
security = HTTPBearer()
|
|
|
|
|
# Data pulled out of a decoded access token.
class TokenData(BaseModel):
    # Populated from the token's "sub" claim; defaults to None.
    username: Optional[str] = None
|
|
|
|
|
# Authenticated user as returned by the auth dependencies in this module.
class User(BaseModel):
    # Login name of the user.
    username: str
|
|
|
|
|
# Shape of an issued-token payload.
class Token(BaseModel):
    # The encoded token string.
    access_token: str
    # Token scheme label (presumably "bearer" — confirm against the caller).
    token_type: str
    # Token lifetime (presumably in seconds — confirm against the caller).
    expires_in: int
|
|
|
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored hash.

    Args:
        plain_password: The password as entered by the user.
        hashed_password: The previously stored hash to check against.

    Returns:
        The result of ``pwd_context.verify`` for the given pair.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
|
|
|
|
|
def get_password_hash(password: str) -> str:
    """Hash a plaintext password with the module-level ``pwd_context``.

    Args:
        password: The plaintext password to hash.

    Returns:
        The hash string produced by ``pwd_context.hash``.
    """
    hashed = pwd_context.hash(password)
    return hashed
|
|
|
|
|
def authenticate_user(username: str, password: str) -> bool:
    """Check supplied credentials against the configured admin account.

    The username is compared with ``Config.ADMIN_USERNAME`` and the password
    with ``Config.ADMIN_PASSWORD``.

    Args:
        username: Username supplied by the client.
        password: Password supplied by the client.

    Returns:
        True only when both the username and the password match the
        configured values; False otherwise (including when any of the four
        values is not a string).
    """
    # NOTE(review): the configured password is compared as-is rather than
    # through verify_password(); confirm whether Config.ADMIN_PASSWORD is
    # meant to hold plaintext.
    expected_username = Config.ADMIN_USERNAME
    expected_password = Config.ADMIN_PASSWORD

    values = (username, password, expected_username, expected_password)
    if not all(isinstance(value, str) for value in values):
        # Missing or misconfigured credentials can never authenticate.
        return False

    # compare_digest runs in time independent of where the inputs differ.
    # Encoding to bytes is required because it rejects non-ASCII str input.
    username_ok = hmac.compare_digest(
        username.encode("utf-8"), expected_username.encode("utf-8")
    )
    password_ok = hmac.compare_digest(
        password.encode("utf-8"), expected_password.encode("utf-8")
    )

    # Both comparisons are always evaluated so that a wrong username is not
    # distinguishable from a wrong password by response time.
    return username_ok and password_ok
|
|
|
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    A copy of ``data`` is extended with an ``exp`` claim and encoded with
    ``Config.JWT_SECRET_KEY`` using ``Config.JWT_ALGORITHM``. The caller's
    dict is not modified.

    Args:
        data: Claims to embed in the token.
        expires_delta: Token lifetime. When None, the lifetime is
            ``Config.JWT_EXPIRE_MINUTES`` minutes.

    Returns:
        The encoded JWT string.
    """
    # Compare against None explicitly: timedelta(0) is falsy, and an
    # explicitly requested zero lifetime must not fall back to the default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=Config.JWT_EXPIRE_MINUTES)

    # Timezone-aware "now" in UTC; datetime.utcnow() returns a naive value
    # and is deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + expires_delta

    to_encode = data.copy()
    to_encode["exp"] = expire
    return jwt.encode(to_encode, Config.JWT_SECRET_KEY, algorithm=Config.JWT_ALGORITHM)
|
|
|
|
|
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)) -> User:
    """Resolve the current user from a bearer token.

    The token is decoded with ``Config.JWT_SECRET_KEY`` and
    ``Config.JWT_ALGORITHM``; its ``sub`` claim must be present and equal to
    ``Config.ADMIN_USERNAME``.

    Args:
        credentials: Bearer credentials supplied by the ``security`` dependency.

    Returns:
        A ``User`` whose username is the token's ``sub`` claim.

    Raises:
        HTTPException: 401 when the token cannot be decoded, has no ``sub``
            claim, or names a user other than the configured admin.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="无法验证凭据",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Only the decode step can raise JWTError, so it is the only guarded call.
    try:
        claims = jwt.decode(
            credentials.credentials,
            Config.JWT_SECRET_KEY,
            algorithms=[Config.JWT_ALGORITHM],
        )
    except JWTError:
        raise unauthorized

    subject = claims.get("sub")
    if subject is None:
        raise unauthorized

    token_data = TokenData(username=subject)
    if token_data.username != Config.ADMIN_USERNAME:
        raise unauthorized

    return User(username=token_data.username)
|
|
|
|
|
def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
    """Return the user resolved by ``get_current_user``.

    The user is passed through unchanged; no additional check is performed
    here.
    """
    return current_user
|
|
|
|
|
|
|
|
async def get_current_authenticated_user(current_user: User = Depends(get_current_active_user)):
    """Dependency that yields the authenticated user.

    The user comes from ``get_current_active_user`` and is returned as-is.
    """
    return current_user
|
|
|
|
|
|
|
|
# Bearer scheme that yields None instead of rejecting the request when no
# bearer credentials are supplied. The default HTTPBearer() raises in that
# case, which would make the dependency below mandatory rather than optional.
_optional_security = HTTPBearer(auto_error=False)


async def get_current_user_optional(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(_optional_security),
) -> Optional[User]:
    """Resolve the current user if valid credentials are supplied.

    Args:
        credentials: Bearer credentials, or None when the request carries
            none.

    Returns:
        The ``User`` from ``get_current_user`` when the token is accepted;
        None when no credentials were sent or ``get_current_user`` raised an
        ``HTTPException``.
    """
    if credentials is None:
        return None

    try:
        return await get_current_user(credentials)
    except HTTPException:
        # An invalid token is treated the same as an anonymous request.
        return None