|
|
""" |
|
|
Comprehensive test suite for authentication system. |
|
|
|
|
|
Tests cover: |
|
|
- JWT token handling |
|
|
- Google OAuth flow (mocked) |
|
|
- CSRF protection |
|
|
- Rate limiting |
|
|
- Session management |
|
|
- User preferences |
|
|
- Anonymous access limits |
|
|
""" |
|
|
|
|
|
import pytest |
|
|
import json |
|
|
import time |
|
|
from datetime import datetime, timedelta |
|
|
from unittest.mock import Mock, patch, AsyncMock |
|
|
from fastapi.testclient import TestClient |
|
|
from sqlalchemy import create_engine |
|
|
from sqlalchemy.orm import sessionmaker |
|
|
from jose import jwt |
|
|
|
|
|
|
|
|
from main import app |
|
|
from database.config import get_db, Base |
|
|
from src.models.auth import User, Session, UserPreferences |
|
|
from auth.auth import create_access_token, verify_token |
|
|
from middleware.csrf import CSRFMiddleware |
|
|
from middleware.auth import AuthMiddleware |
|
|
|
|
|
|
|
|
# File-backed SQLite database used only by these tests.
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"

# sqlite3 rejects using a connection from a thread other than the one that
# created it; the check is disabled so pooled connections can be shared with
# the app under test (presumably served from another thread by TestClient).
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
)

# Session factory used by the fixtures and by the overridden get_db dependency.
TestingSessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# One HTTP client shared by every test in the module (it keeps cookies between requests).
client = TestClient(app)
|
|
|
|
|
|
|
|
def override_get_db():
    """Yield a test-database session in place of the app's ``get_db``.

    The session is opened *before* entering ``try``. If opening it raises,
    nothing is yielded and the ``finally`` clause never runs against an
    unbound ``db`` (which previously turned the real error into a NameError).
    """
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()


# Route every dependency on get_db in the app to the test database.
app.dependency_overrides[get_db] = override_get_db
|
|
|
|
|
|
|
|
@pytest.fixture(scope="module", autouse=True)
def setup_database():
    """Create all ORM tables once for this module and drop them afterwards.

    ``autouse=True`` is required. No test in this module requests this
    fixture by name, so without it ``create_all`` would never run against the
    test database, and tests touching the ORM would fail on missing tables
    unless something else happened to create them.
    """
    Base.metadata.create_all(bind=engine)
    yield
    # pytest runs teardown after the last test in the module, pass or fail.
    Base.metadata.drop_all(bind=engine)
|
|
|
|
|
@pytest.fixture
def db_session():
    """Hand each test its own ORM session on the test database.

    The session is closed after the test regardless of its outcome. Nothing
    is rolled back, so committed rows remain visible to later tests.
    """
    session = TestingSessionLocal()
    try:
        yield session
    finally:
        session.close()
|
|
|
|
|
@pytest.fixture
def test_user(db_session):
    """Return the shared test user, creating and committing it on first use.

    The database lives for the whole module, and ``db_session`` does not roll
    back, so the user committed by an earlier test is still present. It is
    reused rather than inserted again. NOTE(review): a second insert would
    presumably violate a unique constraint on ``User.email``; confirm
    against the model.
    """
    email = "test@example.com"
    user = db_session.query(User).filter(User.email == email).first()
    if user is None:
        user = User(
            email=email,
            name="Test User",
            email_verified=True,
        )
        db_session.add(user)
        db_session.commit()
        db_session.refresh(user)
    return user
|
|
|
|
|
@pytest.fixture
def auth_headers(test_user):
    """Build an ``Authorization: Bearer <jwt>`` header dict for ``test_user``."""
    claims = {"sub": str(test_user.id), "email": test_user.email}
    bearer = create_access_token(data=claims)
    return {"Authorization": "Bearer " + bearer}
|
|
|
|
|
|
|
|
class TestJWTTokenHandling:
    """Exercise JWT access-token creation and verification."""

    @staticmethod
    def _claims_for(user):
        """Return the standard claim set (``sub`` and ``email``) for *user*."""
        return {"sub": str(user.id), "email": user.email}

    def test_create_access_token(self, test_user):
        """A freshly created token comes back as a string."""
        minted = create_access_token(data=self._claims_for(test_user))
        assert minted is not None
        assert isinstance(minted, str)

    def test_verify_valid_token(self, test_user):
        """Verifying a token we just created returns its original claims."""
        decoded = verify_token(create_access_token(data=self._claims_for(test_user)))
        assert decoded is not None
        assert decoded["sub"] == str(test_user.id)
        assert decoded["email"] == test_user.email

    def test_verify_invalid_token(self):
        """Garbage input is rejected with ``None``."""
        assert verify_token("invalid_token") is None

    def test_verify_expired_token(self, test_user):
        """A correctly signed token whose ``exp`` is in the past is rejected."""
        signing_key = "test_secret"
        stale_token = jwt.encode(
            {"sub": str(test_user.id), "exp": datetime.utcnow() - timedelta(minutes=1)},
            signing_key,
            algorithm="HS256",
        )
        # Patch the module key so the signature matches and a None result is
        # attributable to expiry. This presumes verify_token reads
        # auth.auth.JWT_SECRET_KEY at call time; confirm.
        with patch('auth.auth.JWT_SECRET_KEY', signing_key):
            outcome = verify_token(stale_token)
        assert outcome is None
|
|
|
|
|
|
|
|
class TestCSRFProtection:
    """Directly exercise ``CSRFMiddleware`` helper methods."""

    def test_csrf_token_generation(self):
        """``_get_or_generate_token`` produces a non-empty token."""
        csrf = CSRFMiddleware(app)
        generated = csrf._get_or_generate_token(Mock())
        assert generated is not None
        assert len(generated) > 0

    def test_csrf_token_validation_success(self):
        """A header token equal to the expected token validates without raising."""
        import asyncio

        expected = "test_token"
        csrf = CSRFMiddleware(app)
        fake_request = Mock()
        fake_request.headers = {"X-CSRF-Token": expected}
        csrf._get_or_generate_token = Mock(return_value=expected)

        # Success is signalled only by the absence of an exception.
        asyncio.run(csrf._validate_csrf_token(fake_request, expected))

    def test_csrf_token_validation_failure(self):
        """A header token that differs from the expected one raises."""
        import asyncio

        csrf = CSRFMiddleware(app)
        fake_request = Mock()
        fake_request.headers = {"X-CSRF-Token": "wrong_token"}

        # NOTE(review): the concrete exception type is not visible here;
        # narrowing ``Exception`` would make this assertion stricter.
        with pytest.raises(Exception):
            asyncio.run(csrf._validate_csrf_token(fake_request, "correct_token"))
|
|
|
|
|
|
|
|
class TestAuthenticationEndpoints:
    """HTTP-level tests for the ``/auth`` routes.

    NOTE(review): FastAPI resolves ``Depends(...)`` targets captured when the
    route was declared, so patching ``routes.auth.get_current_active_user``
    may have no effect. These tests may pass only because ``auth_headers``
    carries a real token for the test user. Consider
    ``app.dependency_overrides`` instead; confirm.
    """

    def test_get_current_user_unauthorized(self):
        """``GET /auth/me`` without credentials is rejected with 401."""
        assert client.get("/auth/me").status_code == 401

    def test_get_current_authorized(self, auth_headers):
        """``GET /auth/me`` with a bearer token returns the user's profile."""
        fake_user = Mock(id=1, email="test@example.com", name="Test User")
        with patch('routes.auth.get_current_active_user', return_value=fake_user):
            reply = client.get("/auth/me", headers=auth_headers)
            assert reply.status_code == 200
            assert reply.json()["email"] == "test@example.com"

    def test_logout_success(self, auth_headers):
        """``POST /auth/logout`` succeeds and answers with a ``message`` field."""
        fake_user = Mock(id=1, email="test@example.com")
        with patch('routes.auth.get_current_active_user', return_value=fake_user):
            reply = client.post("/auth/logout", headers=auth_headers)
            assert reply.status_code == 200
            assert "message" in reply.json()

    def test_google_oauth_redirect(self):
        """Starting the Google login answers with either 200 or 302."""
        status = client.get("/auth/login/google").status_code
        assert status in (200, 302)
|
|
|
|
|
|
|
|
class TestRateLimiting:
    """Tests for request rate limiting.

    Both tests skip, rather than pass vacuously, when rate limiting is not
    observable in the test configuration.
    """

    def test_rate_limit_headers(self, auth_headers):
        """Rate-limit headers are present on a normal authenticated response."""
        with patch('routes.auth.get_current_active_user') as mock_user:
            mock_user.return_value = Mock(id=1, email="test@example.com")
            response = client.get("/auth/me", headers=auth_headers)

        # The previous version asserted nothing. Header names depend on the
        # limiter (commonly X-RateLimit-* / X-Rate-Limit-*). TODO(review): pin
        # the exact header name and assert it instead of skipping.
        limit_headers = [
            header
            for header in response.headers
            if "ratelimit" in header.lower().replace("-", "")
        ]
        if not limit_headers:
            pytest.skip("no rate-limit headers emitted; limiter may be disabled in tests")

    def test_rate_limit_exceeded(self):
        """Repeated requests eventually receive 429 with a rate-limit message."""
        # NOTE(review): exhausting the limiter here may leave later tests that hit
        # /auth/login/google rate-limited, depending on the limiter's window.
        for _ in range(50):
            response = client.get("/auth/login/google")
            if response.status_code == 429:
                assert "rate limit" in response.text.lower()
                break
        else:
            # Previously ``assert True``, which reported the unverified case as a pass.
            pytest.skip("rate limit not triggered within 50 requests")
|
|
|
|
|
|
|
|
class TestSessionManagement:
    """Tests for persisting and invalidating user sessions."""

    def test_create_user_session(self, db_session, test_user):
        """``create_user_session`` stores a ``Session`` row for the returned token."""
        from auth.auth import create_user_session

        token = create_user_session(db_session, test_user)
        assert token is not None

        # Look the row up by token instead of taking the user's first session.
        # The database is shared across the module without rollback, so other
        # sessions for this user may already exist.
        session = db_session.query(Session).filter(Session.token == token).first()
        assert session is not None
        assert session.user_id == test_user.id

    def test_invalidate_user_sessions(self, db_session, test_user):
        """``invalidate_user_sessions`` removes every session of the user."""
        from auth.auth import create_user_session, invalidate_user_sessions

        create_user_session(db_session, test_user)
        create_user_session(db_session, test_user)

        user_sessions = db_session.query(Session).filter(
            Session.user_id == test_user.id
        )
        # Precondition: without it, sessions that were never persisted would
        # make the final assertion pass vacuously.
        assert user_sessions.count() > 0

        invalidate_user_sessions(db_session, test_user)

        # The query is re-executed, so this reflects the post-invalidation state.
        assert user_sessions.count() == 0
|
|
|
|
|
|
|
|
class TestAnonymousAccess:
    """Tests for anonymous-session handling in ``AuthMiddleware``."""

    def test_anonymous_session_creation(self):
        """An unauthenticated request gets an anonymous session on ``request.state``."""
        import asyncio
        import types

        middleware = AuthMiddleware(app)

        request = Mock()
        request.headers = {}
        # Use a plain namespace rather than Mock(). A Mock fabricates any
        # attribute on access, so hasattr(request.state, 'session_id') was true
        # no matter what the middleware did.
        request.state = types.SimpleNamespace()

        asyncio.run(middleware._handle_anonymous_request(request))

        assert hasattr(request.state, 'session_id')
        assert request.state.anonymous is True

    def test_anonymous_message_limit(self):
        """A session already at ``anonymous_limit`` messages is refused."""
        import asyncio

        middleware = AuthMiddleware(app, anonymous_limit=2)

        # Seed an anonymous session that has already used its full allowance.
        session_id = "test_session"
        middleware._anonymous_sessions[session_id] = {
            "message_count": 2,
            "created_at": datetime.utcnow(),
            "last_activity": datetime.utcnow(),
        }

        request = Mock()
        request.headers = {"X-Anonymous-Session-ID": session_id}
        # Mock state is kept here on purpose: with a strict namespace, an
        # unrelated AttributeError would also satisfy pytest.raises(Exception).
        request.state = Mock()

        # NOTE(review): the concrete exception type is not visible here
        # (presumably an HTTP error); narrowing it would make this test stricter.
        with pytest.raises(Exception):
            asyncio.run(middleware._handle_anonymous_request(request))
|
|
|
|
|
|
|
|
class TestUserPreferences:
    """Tests for reading and updating ``/auth/preferences``.

    NOTE(review): as elsewhere in this module, patching
    ``routes.auth.get_current_active_user`` may not affect FastAPI dependency
    resolution; the real token in ``auth_headers`` may determine the user.
    Confirm.
    """

    def test_get_user_preferences_not_found(self, auth_headers):
        """With no stored preferences the endpoint answers with defaults (light theme)."""
        fake_user = Mock(id=999, email="test@example.com")
        with patch('routes.auth.get_current_active_user', return_value=fake_user):
            reply = client.get("/auth/preferences", headers=auth_headers)
            assert reply.status_code == 200
            body = reply.json()
            assert "theme" in body
            assert body["theme"] == "light"

    def test_update_user_preferences(self, auth_headers):
        """``PUT /auth/preferences`` stores and echoes the submitted values."""
        fake_user = Mock(id=1, email="test@example.com")
        new_prefs = dict(
            theme="dark",
            language="en",
            notifications_enabled=False,
            chat_settings={"model": "gpt-4"},
        )
        with patch('routes.auth.get_current_active_user', return_value=fake_user):
            reply = client.put("/auth/preferences", json=new_prefs, headers=auth_headers)
            assert reply.status_code == 200
            body = reply.json()
            assert body["theme"] == "dark"
            assert body["notifications_enabled"] is False
|
|
|
|
|
|
|
|
class TestSecurityHeaders:
    """Tests for CORS and CSRF-related response headers."""

    # NOTE(review): assumed to be an allowed origin in the app's CORS
    # configuration; confirm against the app's allow_origins setting.
    CORS_TEST_ORIGIN = "http://localhost:3000"

    def test_cors_headers(self):
        """A CORS preflight for ``/auth/me`` is answered with an allow-origin header.

        CORS middleware such as Starlette's ``CORSMiddleware`` only adds
        ``Access-Control-Allow-Origin`` to requests that carry an ``Origin``
        header. It treats OPTIONS as a preflight only when
        ``Access-Control-Request-Method`` is also present. The former bare
        OPTIONS request sent neither, so the header could never appear.
        """
        response = client.options(
            "/auth/me",
            headers={
                "Origin": self.CORS_TEST_ORIGIN,
                "Access-Control-Request-Method": "GET",
            },
        )
        assert "access-control-allow-origin" in response.headers

    def test_csrf_cookie_set(self):
        """CSRF cookie is set on first request."""
        response = client.get("/auth/me")
        # TODO(review): this test still asserts nothing. The module-level client
        # keeps cookies between tests, so this is not really the "first" request.
        # Use a fresh TestClient(app) and assert on the CSRF cookie name once it
        # is known.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestErrorHandling:
    """Negative-path tests: malformed or missing credentials and hostile input."""

    def test_invalid_token_format(self):
        """A JWT-shaped but invalid bearer token is rejected with 401."""
        bogus_auth = {"Authorization": "Bearer invalid.token.format"}
        reply = client.get("/auth/me", headers=bogus_auth)
        assert reply.status_code == 401

    def test_missing_authorization_header(self):
        """A request carrying no ``Authorization`` header is rejected with 401."""
        reply = client.get("/auth/me")
        assert reply.status_code == 401

    def test_sql_injection_attempts(self):
        """Injection-looking input in a request body is rejected, not executed."""
        # NOTE(review): no credentials are sent, so the request is presumably
        # rejected before the body is used. This does not exercise query
        # construction itself.
        hostile_body = {"email": "'; DROP TABLE users; --"}
        reply = client.post("/auth/logout", json=hostile_body)
        assert reply.status_code in (401, 422)
|
|
|
|
|
|
|
|
|
|
|
class TestAuthenticationFlow:
    """End-to-end style tests of the authentication flow."""

    def test_full_oauth_flow_simulation(self):
        """Start the Google OAuth flow with the provider and token helpers mocked.

        ``authorize_redirect`` is patched with ``AsyncMock`` because Authlib's
        Starlette client defines it as a coroutine function, presumably awaited
        by the route. Awaiting the plain ``Mock()`` result used previously
        raises ``TypeError``.
        """
        from fastapi import Response

        with patch(
            'routes.auth.oauth.google.authorize_redirect', new_callable=AsyncMock
        ) as mock_auth:
            with patch('routes.auth.get_or_create_user') as mock_user:
                with patch('routes.auth.create_user_session') as mock_session:
                    with patch('routes.auth.create_access_token') as mock_token:
                        # The endpoint returns this Response unchanged, so it
                        # needs no serialization. A 200 (rather than a redirect)
                        # keeps TestClient, which follows redirects by default,
                        # from chasing an external provider URL.
                        mock_auth.return_value = Response(status_code=200)
                        mock_user.return_value = Mock(id=1, email="test@example.com")
                        mock_session.return_value = "session_token"
                        mock_token.return_value = "jwt_token"

                        response = client.get("/auth/login/google")
                        assert response.status_code in [200, 302]
                        # The route must actually delegate to the OAuth client.
                        mock_auth.assert_awaited_once()

    def test_session_expiry(self, db_session, test_user):
        """A session whose ``expires_at`` is in the past is reported invalid."""
        from auth.auth import check_session_validity

        expired_time = datetime.utcnow() - timedelta(days=1)

        # Insert an already-expired session directly, bypassing the session helpers.
        session = Session(
            user_id=test_user.id,
            token="expired_token",
            expires_at=expired_time,
        )
        db_session.add(session)
        db_session.commit()

        is_valid = check_session_validity("expired_token", db_session)
        assert is_valid is False
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly reports
    # failures to the shell or CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))