Spaces:
Sleeping
Sleeping
Update src/routers/auth.py
Browse files- src/routers/auth.py +52 -34
src/routers/auth.py
CHANGED
|
@@ -2,27 +2,21 @@ from fastapi import APIRouter, Depends, HTTPException, status, Response, Request
|
|
| 2 |
from sqlalchemy.orm import Session
|
| 3 |
from typing import Any
|
| 4 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 5 |
-
|
| 6 |
from src.schemas.auth import RegisterRequest, RegisterResponse, LoginRequest, LoginResponse
|
| 7 |
from src.models.user import User
|
| 8 |
from src.utils.security import get_password_hash, verify_password, create_access_token, verify_token
|
| 9 |
from src.utils.deps import get_db
|
| 10 |
-
|
| 11 |
router = APIRouter(tags=["Authentication"])
|
| 12 |
|
| 13 |
-
|
| 14 |
@router.post("/register", response_model=RegisterResponse)
|
| 15 |
def register(user_data: RegisterRequest, db: Session = Depends(get_db)) -> RegisterResponse:
|
| 16 |
"""
|
| 17 |
Register a new user with email and password.
|
| 18 |
-
|
| 19 |
Args:
|
| 20 |
user_data: Registration request containing email and password
|
| 21 |
db: Database session dependency
|
| 22 |
-
|
| 23 |
Returns:
|
| 24 |
RegisterResponse: Created user information
|
| 25 |
-
|
| 26 |
Raises:
|
| 27 |
HTTPException: If email is invalid format (handled by Pydantic) or email already exists
|
| 28 |
"""
|
|
@@ -33,28 +27,23 @@ def register(user_data: RegisterRequest, db: Session = Depends(get_db)) -> Regis
|
|
| 33 |
status_code=status.HTTP_409_CONFLICT,
|
| 34 |
detail="An account with this email already exists"
|
| 35 |
)
|
| 36 |
-
|
| 37 |
# Validate password length (minimum 8 characters)
|
| 38 |
if len(user_data.password) < 8:
|
| 39 |
raise HTTPException(
|
| 40 |
status_code=status.HTTP_400_BAD_REQUEST,
|
| 41 |
detail="Password must be at least 8 characters"
|
| 42 |
)
|
| 43 |
-
|
| 44 |
# Create password hash (password truncation to 72 bytes is handled in get_password_hash)
|
| 45 |
password_hash = get_password_hash(user_data.password)
|
| 46 |
-
|
| 47 |
# Create new user
|
| 48 |
user = User(
|
| 49 |
email=user_data.email,
|
| 50 |
password_hash=password_hash
|
| 51 |
)
|
| 52 |
-
|
| 53 |
# Add user to database
|
| 54 |
db.add(user)
|
| 55 |
db.commit()
|
| 56 |
db.refresh(user)
|
| 57 |
-
|
| 58 |
# Return response
|
| 59 |
return RegisterResponse(
|
| 60 |
id=str(user.id),
|
|
@@ -62,26 +51,21 @@ def register(user_data: RegisterRequest, db: Session = Depends(get_db)) -> Regis
|
|
| 62 |
created_at=user.created_at
|
| 63 |
)
|
| 64 |
|
| 65 |
-
|
| 66 |
@router.post("/login", response_model=LoginResponse)
|
| 67 |
-
def login(login_data: LoginRequest, response: Response, db: Session = Depends(get_db)) -> LoginResponse:
|
| 68 |
"""
|
| 69 |
Authenticate user and return JWT token.
|
| 70 |
-
|
| 71 |
Args:
|
| 72 |
login_data: Login request containing email and password
|
| 73 |
response: FastAPI response object to set cookies
|
| 74 |
db: Database session dependency
|
| 75 |
-
|
| 76 |
Returns:
|
| 77 |
LoginResponse: JWT token and user information
|
| 78 |
-
|
| 79 |
Raises:
|
| 80 |
HTTPException: If credentials are invalid
|
| 81 |
"""
|
| 82 |
# Find user by email
|
| 83 |
user = db.query(User).filter(User.email == login_data.email).first()
|
| 84 |
-
|
| 85 |
# Check if user exists and password is correct (password truncation to 72 bytes is handled in verify_password)
|
| 86 |
if not user or not verify_password(login_data.password, user.password_hash):
|
| 87 |
raise HTTPException(
|
|
@@ -89,20 +73,43 @@ def login(login_data: LoginRequest, response: Response, db: Session = Depends(ge
|
|
| 89 |
detail="Invalid email or password",
|
| 90 |
headers={"WWW-Authenticate": "Bearer"},
|
| 91 |
)
|
| 92 |
-
|
| 93 |
# Create access token
|
| 94 |
access_token = create_access_token(data={"sub": str(user.id)})
|
| 95 |
-
|
| 96 |
# Set the token in an httpOnly cookie for security
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
| 104 |
)
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
# Return response
|
| 107 |
return LoginResponse(
|
| 108 |
access_token=access_token,
|
|
@@ -111,22 +118,33 @@ def login(login_data: LoginRequest, response: Response, db: Session = Depends(ge
|
|
| 111 |
email=user.email
|
| 112 |
)
|
| 113 |
|
| 114 |
-
|
| 115 |
@router.get("/me")
|
| 116 |
def get_current_user(request: Request, db: Session = Depends(get_db)):
|
| 117 |
"""
|
| 118 |
-
Get current authenticated user information.
|
| 119 |
This endpoint is used to check if a user is authenticated and get their info.
|
| 120 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
# Get the token from cookies
|
| 122 |
token = request.cookies.get("access_token")
|
| 123 |
|
| 124 |
if not token:
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
# Verify the token
|
| 132 |
payload = verify_token(token)
|
|
|
|
| 2 |
from sqlalchemy.orm import Session
|
| 3 |
from typing import Any
|
| 4 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
|
|
| 5 |
from src.schemas.auth import RegisterRequest, RegisterResponse, LoginRequest, LoginResponse
|
| 6 |
from src.models.user import User
|
| 7 |
from src.utils.security import get_password_hash, verify_password, create_access_token, verify_token
|
| 8 |
from src.utils.deps import get_db
|
|
|
|
| 9 |
router = APIRouter(tags=["Authentication"])
|
| 10 |
|
|
|
|
| 11 |
@router.post("/register", response_model=RegisterResponse)
|
| 12 |
def register(user_data: RegisterRequest, db: Session = Depends(get_db)) -> RegisterResponse:
|
| 13 |
"""
|
| 14 |
Register a new user with email and password.
|
|
|
|
| 15 |
Args:
|
| 16 |
user_data: Registration request containing email and password
|
| 17 |
db: Database session dependency
|
|
|
|
| 18 |
Returns:
|
| 19 |
RegisterResponse: Created user information
|
|
|
|
| 20 |
Raises:
|
| 21 |
HTTPException: If email is invalid format (handled by Pydantic) or email already exists
|
| 22 |
"""
|
|
|
|
| 27 |
status_code=status.HTTP_409_CONFLICT,
|
| 28 |
detail="An account with this email already exists"
|
| 29 |
)
|
|
|
|
| 30 |
# Validate password length (minimum 8 characters)
|
| 31 |
if len(user_data.password) < 8:
|
| 32 |
raise HTTPException(
|
| 33 |
status_code=status.HTTP_400_BAD_REQUEST,
|
| 34 |
detail="Password must be at least 8 characters"
|
| 35 |
)
|
|
|
|
| 36 |
# Create password hash (password truncation to 72 bytes is handled in get_password_hash)
|
| 37 |
password_hash = get_password_hash(user_data.password)
|
|
|
|
| 38 |
# Create new user
|
| 39 |
user = User(
|
| 40 |
email=user_data.email,
|
| 41 |
password_hash=password_hash
|
| 42 |
)
|
|
|
|
| 43 |
# Add user to database
|
| 44 |
db.add(user)
|
| 45 |
db.commit()
|
| 46 |
db.refresh(user)
|
|
|
|
| 47 |
# Return response
|
| 48 |
return RegisterResponse(
|
| 49 |
id=str(user.id),
|
|
|
|
| 51 |
created_at=user.created_at
|
| 52 |
)
|
| 53 |
|
|
|
|
| 54 |
@router.post("/login", response_model=LoginResponse)
|
| 55 |
+
def login(login_data: LoginRequest, request: Request, response: Response, db: Session = Depends(get_db)) -> LoginResponse:
|
| 56 |
"""
|
| 57 |
Authenticate user and return JWT token.
|
|
|
|
| 58 |
Args:
|
| 59 |
login_data: Login request containing email and password
|
| 60 |
response: FastAPI response object to set cookies
|
| 61 |
db: Database session dependency
|
|
|
|
| 62 |
Returns:
|
| 63 |
LoginResponse: JWT token and user information
|
|
|
|
| 64 |
Raises:
|
| 65 |
HTTPException: If credentials are invalid
|
| 66 |
"""
|
| 67 |
# Find user by email
|
| 68 |
user = db.query(User).filter(User.email == login_data.email).first()
|
|
|
|
| 69 |
# Check if user exists and password is correct (password truncation to 72 bytes is handled in verify_password)
|
| 70 |
if not user or not verify_password(login_data.password, user.password_hash):
|
| 71 |
raise HTTPException(
|
|
|
|
| 73 |
detail="Invalid email or password",
|
| 74 |
headers={"WWW-Authenticate": "Bearer"},
|
| 75 |
)
|
|
|
|
| 76 |
# Create access token
|
| 77 |
access_token = create_access_token(data={"sub": str(user.id)})
|
|
|
|
| 78 |
# Set the token in an httpOnly cookie for security
|
| 79 |
+
# For cross-domain (Vercel frontend + HuggingFace backend), use samesite="none" and secure=True
|
| 80 |
+
import os
|
| 81 |
+
# Detect if we're in production (HTTPS) or development (HTTP)
|
| 82 |
+
# HuggingFace Spaces always uses HTTPS, so we need cross-domain cookies
|
| 83 |
+
is_production = (
|
| 84 |
+
os.getenv("ENVIRONMENT", "").lower() == "production" or
|
| 85 |
+
os.getenv("SPACE_ID") is not None or # HuggingFace Spaces set this
|
| 86 |
+
"hf.space" in os.getenv("SPACE_HOST", "") or # HuggingFace Spaces host
|
| 87 |
+
request.url.scheme == "https" # If request is HTTPS, we're in production
|
| 88 |
)
|
| 89 |
+
|
| 90 |
+
# Always use secure=True and samesite="none" for cross-domain cookies in production
|
| 91 |
+
# This is required for Vercel (frontend) + HuggingFace (backend) setup
|
| 92 |
+
cookie_kwargs = {
|
| 93 |
+
"key": "access_token",
|
| 94 |
+
"value": access_token,
|
| 95 |
+
"httponly": True,
|
| 96 |
+
"secure": True, # Always True - required for HTTPS and cross-domain cookies
|
| 97 |
+
"max_age": 604800, # 7 days in seconds
|
| 98 |
+
"path": "/", # Ensure cookie is available for all paths
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
# For cross-domain cookies, we MUST use samesite="none" with secure=True
|
| 102 |
+
if is_production:
|
| 103 |
+
cookie_kwargs["samesite"] = "none"
|
| 104 |
+
else:
|
| 105 |
+
cookie_kwargs["samesite"] = "lax"
|
| 106 |
+
|
| 107 |
+
response.set_cookie(**cookie_kwargs)
|
| 108 |
+
|
| 109 |
+
# Debug logging (remove in production if needed)
|
| 110 |
+
import logging
|
| 111 |
+
logger = logging.getLogger(__name__)
|
| 112 |
+
logger.info(f"Cookie set: access_token (production={is_production}, samesite={cookie_kwargs.get('samesite')})")
|
| 113 |
# Return response
|
| 114 |
return LoginResponse(
|
| 115 |
access_token=access_token,
|
|
|
|
| 118 |
email=user.email
|
| 119 |
)
|
| 120 |
|
|
|
|
| 121 |
@router.get("/me")
|
| 122 |
def get_current_user(request: Request, db: Session = Depends(get_db)):
|
| 123 |
"""
|
|
|
|
| 124 |
This endpoint is used to check if a user is authenticated and get their info.
|
| 125 |
"""
|
| 126 |
+
# Debug: Log all cookies received
|
| 127 |
+
import logging
|
| 128 |
+
logger = logging.getLogger(__name__)
|
| 129 |
+
logger.info(f"Cookies received: {list(request.cookies.keys())}")
|
| 130 |
+
logger.info(f"Headers: {dict(request.headers)}")
|
| 131 |
+
|
| 132 |
# Get the token from cookies
|
| 133 |
token = request.cookies.get("access_token")
|
| 134 |
|
| 135 |
if not token:
|
| 136 |
+
# Also try Authorization header as fallback
|
| 137 |
+
auth_header = request.headers.get("Authorization")
|
| 138 |
+
if auth_header and auth_header.startswith("Bearer "):
|
| 139 |
+
token = auth_header.split(" ")[1]
|
| 140 |
+
logger.info("Token found in Authorization header")
|
| 141 |
+
else:
|
| 142 |
+
logger.warning("No access_token cookie found and no Authorization header")
|
| 143 |
+
raise HTTPException(
|
| 144 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 145 |
+
detail="Not authenticated - no token found in cookies or headers",
|
| 146 |
+
headers={"WWW-Authenticate": "Bearer"},
|
| 147 |
+
)
|
| 148 |
|
| 149 |
# Verify the token
|
| 150 |
payload = verify_token(token)
|