|
|
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from dotenv import load_dotenv
from fastapi import HTTPException, status
from pydantic import BaseModel

from .config import global_args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load environment variables from a ".env" file (a relative path, so it is
# resolved against the process's current working directory). override=False
# leaves variables that are already set in the environment untouched.
# NOTE(review): this runs after `from .config import global_args` above; if
# config reads the environment at import time, values from .env may not reach
# global_args — confirm the intended ordering.
load_dotenv(dotenv_path=".env", override=False)
|
|
|
|
|
|
|
|
class TokenPayload(BaseModel):
    """Claims that ``AuthHandler.create_token`` encodes into a JWT.

    The model is converted with ``.dict()`` and passed to ``jwt.encode``.
    """

    # Subject claim: create_token stores the username here.
    sub: str
    # Expiration time: create_token sets it to "now (UTC) + lifetime in hours".
    exp: datetime
    # Role name; create_token picks a different default lifetime for "guest".
    role: str = "user"
    # Arbitrary extra data carried in the token.
    # NOTE(review): the mutable `{}` default relies on pydantic copying field
    # defaults per instance — confirm for the pinned pydantic version.
    metadata: dict = {}
|
|
|
|
|
|
|
|
class AuthHandler:
    """Issue and validate JWT access tokens.

    All settings are read from ``global_args`` once, at construction time:
    ``token_secret`` / ``jwt_algorithm`` for signing, ``token_expire_hours`` /
    ``guest_token_expire_hours`` for default token lifetimes, and
    ``auth_accounts`` for the ``username -> password`` map in ``self.accounts``.
    """

    def __init__(self):
        self.secret = global_args.token_secret
        self.algorithm = global_args.jwt_algorithm
        self.expire_hours = global_args.token_expire_hours
        self.guest_expire_hours = global_args.guest_token_expire_hours
        # username -> password, parsed from "user1:pass1,user2:pass2".
        self.accounts = self._parse_accounts(global_args.auth_accounts)

    @staticmethod
    def _parse_accounts(raw: Optional[str]) -> dict:
        """Parse ``"user1:pass1,user2:pass2"`` into ``{username: password}``.

        Surrounding whitespace of each entry and of the username is ignored,
        and empty entries (e.g. from a trailing comma) are skipped. Only the
        first ``:`` separates username from password, so passwords may contain
        ``:``. A later duplicate username overrides an earlier one.

        Args:
            raw: The comma-separated account list; ``None`` or ``""`` means
                no accounts.

        Returns:
            dict: Mapping of username to password.

        Raises:
            ValueError: If a non-empty entry has no ``:`` or an empty username.
                The message reports only the entry's position, never its text,
                so credentials do not end up in logs.
        """
        accounts = {}
        if not raw:
            return accounts
        for position, entry in enumerate(raw.split(","), start=1):
            entry = entry.strip()
            if not entry:
                continue
            username, sep, password = entry.partition(":")
            username = username.strip()
            if not sep or not username:
                raise ValueError(
                    f"Malformed auth account entry #{position}: "
                    "expected 'username:password'"
                )
            accounts[username] = password
        return accounts

    @staticmethod
    def _unauthorized(detail: str) -> "HTTPException":
        """Build the HTTP 401 error used for every token rejection."""
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=detail
        )

    def create_token(
        self,
        username: str,
        role: str = "user",
        custom_expire_hours: Optional[int] = None,
        metadata: Optional[dict] = None,
    ) -> str:
        """
        Create JWT token

        Args:
            username: Username, stored in the "sub" claim
            role: User role, default is "user", guest is "guest"
            custom_expire_hours: Custom expiration time (hours), if None use
                the default for the role (guest_expire_hours for "guest",
                expire_hours otherwise)
            metadata: Additional metadata; None is stored as an empty dict

        Returns:
            str: Encoded JWT token
        """
        if custom_expire_hours is not None:
            expire_hours = custom_expire_hours
        elif role == "guest":
            expire_hours = self.guest_expire_hours
        else:
            expire_hours = self.expire_hours

        # Timezone-aware UTC (datetime.utcnow() is deprecated since 3.12);
        # PyJWT converts a datetime "exp" to a UTC epoch timestamp on encode.
        expire = datetime.now(timezone.utc) + timedelta(hours=expire_hours)

        payload = TokenPayload(
            sub=username, exp=expire, role=role, metadata=metadata or {}
        )
        return jwt.encode(payload.dict(), self.secret, algorithm=self.algorithm)

    def validate_token(self, token: str) -> dict:
        """
        Validate JWT token

        Args:
            token: JWT token

        Returns:
            dict: Dictionary containing user information: "username",
                "role" (default "user"), "metadata" (default {}), and "exp"
                as a naive datetime in UTC

        Raises:
            HTTPException: 401 "Token expired" if the token has expired,
                401 "Invalid token" if it is malformed, badly signed, or
                lacks the "exp"/"sub" claims
        """
        try:
            payload = jwt.decode(token, self.secret, algorithms=[self.algorithm])
        except jwt.ExpiredSignatureError as exc:
            # jwt.decode already checks "exp"; this must be caught before the
            # generic PyJWTError (its base class) to report expiry distinctly.
            raise self._unauthorized("Token expired") from exc
        except jwt.PyJWTError as exc:
            raise self._unauthorized("Invalid token") from exc

        # A correctly signed token may still omit claims (PyJWT validates
        # "exp" only when present); reject it instead of a KeyError -> 500.
        if "exp" not in payload or "sub" not in payload:
            raise self._unauthorized("Invalid token")

        try:
            expire_time = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
        except (TypeError, ValueError, OverflowError, OSError) as exc:
            # Non-numeric or out-of-range "exp" values.
            raise self._unauthorized("Invalid token") from exc

        # Defensive re-check; normally jwt.decode has already rejected
        # expired tokens above.
        if datetime.now(timezone.utc) > expire_time:
            raise self._unauthorized("Token expired")

        return {
            "username": payload["sub"],
            "role": payload.get("role", "user"),
            "metadata": payload.get("metadata", {}),
            # Naive UTC, same as the former datetime.utcfromtimestamp() result.
            "exp": expire_time.replace(tzinfo=None),
        }
|
|
|
|
|
|
|
|
# Shared module-level handler. It is constructed at import time, and
# AuthHandler.__init__ reads its settings from global_args, so those settings
# must already be populated when this module is first imported.
auth_handler = AuthHandler()
|
|
|