taahaasaif commited on
Commit
a7a60af
·
verified ·
1 Parent(s): 86cfb54

Update src/routers/auth.py

Browse files
Files changed (1) hide show
  1. src/routers/auth.py +52 -34
src/routers/auth.py CHANGED
@@ -2,27 +2,21 @@ from fastapi import APIRouter, Depends, HTTPException, status, Response, Request
2
  from sqlalchemy.orm import Session
3
  from typing import Any
4
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
5
-
6
  from src.schemas.auth import RegisterRequest, RegisterResponse, LoginRequest, LoginResponse
7
  from src.models.user import User
8
  from src.utils.security import get_password_hash, verify_password, create_access_token, verify_token
9
  from src.utils.deps import get_db
10
-
11
  router = APIRouter(tags=["Authentication"])
12
 
13
-
14
  @router.post("/register", response_model=RegisterResponse)
15
  def register(user_data: RegisterRequest, db: Session = Depends(get_db)) -> RegisterResponse:
16
  """
17
  Register a new user with email and password.
18
-
19
  Args:
20
  user_data: Registration request containing email and password
21
  db: Database session dependency
22
-
23
  Returns:
24
  RegisterResponse: Created user information
25
-
26
  Raises:
27
  HTTPException: If email is invalid format (handled by Pydantic) or email already exists
28
  """
@@ -33,28 +27,23 @@ def register(user_data: RegisterRequest, db: Session = Depends(get_db)) -> Regis
33
  status_code=status.HTTP_409_CONFLICT,
34
  detail="An account with this email already exists"
35
  )
36
-
37
  # Validate password length (minimum 8 characters)
38
  if len(user_data.password) < 8:
39
  raise HTTPException(
40
  status_code=status.HTTP_400_BAD_REQUEST,
41
  detail="Password must be at least 8 characters"
42
  )
43
-
44
  # Create password hash (password truncation to 72 bytes is handled in get_password_hash)
45
  password_hash = get_password_hash(user_data.password)
46
-
47
  # Create new user
48
  user = User(
49
  email=user_data.email,
50
  password_hash=password_hash
51
  )
52
-
53
  # Add user to database
54
  db.add(user)
55
  db.commit()
56
  db.refresh(user)
57
-
58
  # Return response
59
  return RegisterResponse(
60
  id=str(user.id),
@@ -62,26 +51,21 @@ def register(user_data: RegisterRequest, db: Session = Depends(get_db)) -> Regis
62
  created_at=user.created_at
63
  )
64
 
65
-
66
  @router.post("/login", response_model=LoginResponse)
67
- def login(login_data: LoginRequest, response: Response, db: Session = Depends(get_db)) -> LoginResponse:
68
  """
69
  Authenticate user and return JWT token.
70
-
71
  Args:
72
  login_data: Login request containing email and password
73
  response: FastAPI response object to set cookies
74
  db: Database session dependency
75
-
76
  Returns:
77
  LoginResponse: JWT token and user information
78
-
79
  Raises:
80
  HTTPException: If credentials are invalid
81
  """
82
  # Find user by email
83
  user = db.query(User).filter(User.email == login_data.email).first()
84
-
85
  # Check if user exists and password is correct (password truncation to 72 bytes is handled in verify_password)
86
  if not user or not verify_password(login_data.password, user.password_hash):
87
  raise HTTPException(
@@ -89,20 +73,43 @@ def login(login_data: LoginRequest, response: Response, db: Session = Depends(ge
89
  detail="Invalid email or password",
90
  headers={"WWW-Authenticate": "Bearer"},
91
  )
92
-
93
  # Create access token
94
  access_token = create_access_token(data={"sub": str(user.id)})
95
-
96
  # Set the token in an httpOnly cookie for security
97
- response.set_cookie(
98
- key="access_token",
99
- value=access_token,
100
- httponly=True,
101
- secure=False, # Set to True in production with HTTPS
102
- samesite="lax", # Protects against CSRF
103
- max_age=604800 # 7 days in seconds
 
 
104
  )
105
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  # Return response
107
  return LoginResponse(
108
  access_token=access_token,
@@ -111,22 +118,33 @@ def login(login_data: LoginRequest, response: Response, db: Session = Depends(ge
111
  email=user.email
112
  )
113
 
114
-
115
  @router.get("/me")
116
  def get_current_user(request: Request, db: Session = Depends(get_db)):
117
  """
118
- Get current authenticated user information.
119
  This endpoint is used to check if a user is authenticated and get their info.
120
  """
 
 
 
 
 
 
121
  # Get the token from cookies
122
  token = request.cookies.get("access_token")
123
 
124
  if not token:
125
- raise HTTPException(
126
- status_code=status.HTTP_401_UNAUTHORIZED,
127
- detail="Not authenticated",
128
- headers={"WWW-Authenticate": "Bearer"},
129
- )
 
 
 
 
 
 
 
130
 
131
  # Verify the token
132
  payload = verify_token(token)
 
2
  from sqlalchemy.orm import Session
3
  from typing import Any
4
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 
5
  from src.schemas.auth import RegisterRequest, RegisterResponse, LoginRequest, LoginResponse
6
  from src.models.user import User
7
  from src.utils.security import get_password_hash, verify_password, create_access_token, verify_token
8
  from src.utils.deps import get_db
 
9
  router = APIRouter(tags=["Authentication"])
10
 
 
11
  @router.post("/register", response_model=RegisterResponse)
12
  def register(user_data: RegisterRequest, db: Session = Depends(get_db)) -> RegisterResponse:
13
  """
14
  Register a new user with email and password.
 
15
  Args:
16
  user_data: Registration request containing email and password
17
  db: Database session dependency
 
18
  Returns:
19
  RegisterResponse: Created user information
 
20
  Raises:
21
  HTTPException: If email is invalid format (handled by Pydantic) or email already exists
22
  """
 
27
  status_code=status.HTTP_409_CONFLICT,
28
  detail="An account with this email already exists"
29
  )
 
30
  # Validate password length (minimum 8 characters)
31
  if len(user_data.password) < 8:
32
  raise HTTPException(
33
  status_code=status.HTTP_400_BAD_REQUEST,
34
  detail="Password must be at least 8 characters"
35
  )
 
36
  # Create password hash (password truncation to 72 bytes is handled in get_password_hash)
37
  password_hash = get_password_hash(user_data.password)
 
38
  # Create new user
39
  user = User(
40
  email=user_data.email,
41
  password_hash=password_hash
42
  )
 
43
  # Add user to database
44
  db.add(user)
45
  db.commit()
46
  db.refresh(user)
 
47
  # Return response
48
  return RegisterResponse(
49
  id=str(user.id),
 
51
  created_at=user.created_at
52
  )
53
 
 
54
  @router.post("/login", response_model=LoginResponse)
55
+ def login(login_data: LoginRequest, request: Request, response: Response, db: Session = Depends(get_db)) -> LoginResponse:
56
  """
57
  Authenticate user and return JWT token.
 
58
  Args:
59
  login_data: Login request containing email and password
60
  response: FastAPI response object to set cookies
61
  db: Database session dependency
 
62
  Returns:
63
  LoginResponse: JWT token and user information
 
64
  Raises:
65
  HTTPException: If credentials are invalid
66
  """
67
  # Find user by email
68
  user = db.query(User).filter(User.email == login_data.email).first()
 
69
  # Check if user exists and password is correct (password truncation to 72 bytes is handled in verify_password)
70
  if not user or not verify_password(login_data.password, user.password_hash):
71
  raise HTTPException(
 
73
  detail="Invalid email or password",
74
  headers={"WWW-Authenticate": "Bearer"},
75
  )
 
76
  # Create access token
77
  access_token = create_access_token(data={"sub": str(user.id)})
 
78
  # Set the token in an httpOnly cookie for security
79
+ # For cross-domain (Vercel frontend + HuggingFace backend), use samesite="none" and secure=True
80
+ import os
81
+ # Detect if we're in production (HTTPS) or development (HTTP)
82
+ # HuggingFace Spaces always uses HTTPS, so we need cross-domain cookies
83
+ is_production = (
84
+ os.getenv("ENVIRONMENT", "").lower() == "production" or
85
+ os.getenv("SPACE_ID") is not None or # HuggingFace Spaces set this
86
+ "hf.space" in os.getenv("SPACE_HOST", "") or # HuggingFace Spaces host
87
+ request.url.scheme == "https" # If request is HTTPS, we're in production
88
  )
89
+
90
+ # Always use secure=True and samesite="none" for cross-domain cookies in production
91
+ # This is required for Vercel (frontend) + HuggingFace (backend) setup
92
+ cookie_kwargs = {
93
+ "key": "access_token",
94
+ "value": access_token,
95
+ "httponly": True,
96
+ "secure": True, # Always True - required for HTTPS and cross-domain cookies
97
+ "max_age": 604800, # 7 days in seconds
98
+ "path": "/", # Ensure cookie is available for all paths
99
+ }
100
+
101
+ # For cross-domain cookies, we MUST use samesite="none" with secure=True
102
+ if is_production:
103
+ cookie_kwargs["samesite"] = "none"
104
+ else:
105
+ cookie_kwargs["samesite"] = "lax"
106
+
107
+ response.set_cookie(**cookie_kwargs)
108
+
109
+ # Debug logging (remove in production if needed)
110
+ import logging
111
+ logger = logging.getLogger(__name__)
112
+ logger.info(f"Cookie set: access_token (production={is_production}, samesite={cookie_kwargs.get('samesite')})")
113
  # Return response
114
  return LoginResponse(
115
  access_token=access_token,
 
118
  email=user.email
119
  )
120
 
 
121
  @router.get("/me")
122
  def get_current_user(request: Request, db: Session = Depends(get_db)):
123
  """
 
124
  This endpoint is used to check if a user is authenticated and get their info.
125
  """
126
+ # Debug: Log all cookies received
127
+ import logging
128
+ logger = logging.getLogger(__name__)
129
+ logger.info(f"Cookies received: {list(request.cookies.keys())}")
130
+ logger.info(f"Headers: {dict(request.headers)}")
131
+
132
  # Get the token from cookies
133
  token = request.cookies.get("access_token")
134
 
135
  if not token:
136
+ # Also try Authorization header as fallback
137
+ auth_header = request.headers.get("Authorization")
138
+ if auth_header and auth_header.startswith("Bearer "):
139
+ token = auth_header.split(" ")[1]
140
+ logger.info("Token found in Authorization header")
141
+ else:
142
+ logger.warning("No access_token cookie found and no Authorization header")
143
+ raise HTTPException(
144
+ status_code=status.HTTP_401_UNAUTHORIZED,
145
+ detail="Not authenticated - no token found in cookies or headers",
146
+ headers={"WWW-Authenticate": "Bearer"},
147
+ )
148
 
149
  # Verify the token
150
  payload = verify_token(token)