GitHub Actions
Deploy backend from GitHub Actions
84b0fa3
"""
Comprehensive test suite for authentication system.
Tests cover:
- JWT token handling
- Google OAuth flow (mocked)
- CSRF protection
- Rate limiting
- Session management
- User preferences
- Anonymous access limits
"""
import pytest
import json
import time
from datetime import datetime, timedelta
from unittest.mock import Mock, patch, AsyncMock
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from jose import jwt
# Import modules to test
from main import app
from database.config import get_db, Base
from src.models.auth import User, Session, UserPreferences
from auth.auth import create_access_token, verify_token
from middleware.csrf import CSRFMiddleware
from middleware.auth import AuthMiddleware
# Setup test database
# SQLite database stored in the file ./test.db, resolved relative to the
# current working directory.
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
# check_same_thread=False is forwarded to the sqlite3 driver so a connection
# may be used from a thread other than the one that created it.
# NOTE(review): presumably needed because TestClient serves requests on a
# different thread than the test body — confirm.
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
# Session factory bound to the test engine; used by the fixtures and by the
# get_db override defined later in this module.
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Create test client
# One module-wide client shared by every test that issues HTTP requests.
client = TestClient(app)
# Override dependency for testing
def override_get_db():
    """Yield a session from ``TestingSessionLocal`` in place of the app's ``get_db``.

    Written as a generator dependency: the session is yielded to the
    request handler and closed once the handler is done, whether or not it
    raised.

    Yields:
        A session bound to the test engine.
    """
    # Create the session *before* entering ``try``. If the factory raises,
    # the original error propagates instead of being masked by an
    # ``UnboundLocalError`` on ``db`` in the ``finally`` clause.
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()


# Route every ``Depends(get_db)`` in the application to the test database.
app.dependency_overrides[get_db] = override_get_db
# Setup and teardown
@pytest.fixture(scope="module", autouse=True)
def setup_database():
    """Create the test database tables for this module and drop them afterwards.

    The fixture is ``autouse`` because no test or fixture in this module
    requests it by name; without that flag pytest would never run it and
    the tables registered on ``Base.metadata`` would not exist when the
    database-backed tests execute.
    """
    Base.metadata.create_all(bind=engine)
    yield
    # Teardown: runs once after the last test in the module.
    Base.metadata.drop_all(bind=engine)
@pytest.fixture
def db_session():
    """Provide each test with its own session from ``TestingSessionLocal``.

    The session is handed to the test and closed during fixture teardown,
    regardless of whether the test passed or raised.
    """
    session = TestingSessionLocal()
    try:
        yield session
    finally:
        # Executed after the requesting test has finished.
        session.close()
@pytest.fixture
def test_user(db_session):
    """Insert a verified user with a fixed e-mail address and return it.

    The row is committed through ``db_session`` and then refreshed so the
    returned object reflects the state stored in the database.
    """
    attributes = {
        "email": "test@example.com",
        "name": "Test User",
        "email_verified": True,
    }
    account = User(**attributes)
    db_session.add(account)
    db_session.commit()
    # NOTE(review): the committed row is never removed by this fixture. If
    # ``User.email`` carries a unique constraint, a second test requesting
    # this fixture in the same module would fail on insert — confirm
    # against the model definition.
    db_session.refresh(account)
    return account
@pytest.fixture
def auth_headers(test_user):
    """Build an ``Authorization`` header with a bearer token for ``test_user``.

    The token is issued by ``create_access_token`` with the user's id (as
    a string) in ``sub`` and the user's e-mail in ``email``.
    """
    claims = {"sub": str(test_user.id), "email": test_user.email}
    bearer = create_access_token(data=claims)
    return {"Authorization": f"Bearer {bearer}"}
class TestJWTTokenHandling:
    """Exercise ``create_access_token`` and ``verify_token``."""

    def test_create_access_token(self, test_user):
        """Creating a token for the user's claims yields a string."""
        claims = {"sub": str(test_user.id), "email": test_user.email}
        encoded = create_access_token(data=claims)
        assert encoded is not None
        assert isinstance(encoded, str)

    def test_verify_valid_token(self, test_user):
        """A freshly issued token verifies and round-trips its claims."""
        claims = {"sub": str(test_user.id), "email": test_user.email}
        decoded = verify_token(create_access_token(data=claims))
        assert decoded is not None
        assert decoded["sub"] == str(test_user.id)
        assert decoded["email"] == test_user.email

    def test_verify_invalid_token(self):
        """A string that is not a JWT is rejected with ``None``."""
        assert verify_token("invalid_token") is None

    def test_verify_expired_token(self, test_user):
        """A token whose ``exp`` lies one minute in the past is rejected."""
        signing_key = "test_secret"
        stale_claims = {
            "sub": str(test_user.id),
            "exp": datetime.utcnow() - timedelta(minutes=1),
        }
        stale_token = jwt.encode(stale_claims, signing_key, algorithm="HS256")
        # Make the verifier use the same key the token was signed with.
        # NOTE(review): assumes ``verify_token`` reads
        # ``auth.auth.JWT_SECRET_KEY`` at call time and accepts HS256 —
        # otherwise the token is rejected for its signature, not its expiry.
        with patch('auth.auth.JWT_SECRET_KEY', signing_key):
            decoded = verify_token(stale_token)
        assert decoded is None
class TestCSRFProtection:
    """Exercise the private token helpers of ``CSRFMiddleware``."""

    def test_csrf_token_generation(self):
        """Asking for a token with a mock request yields a non-empty value."""
        csrf = CSRFMiddleware(app)
        generated = csrf._get_or_generate_token(Mock())
        assert generated is not None
        assert len(generated) > 0

    def test_csrf_token_validation_success(self):
        """Validation completes without raising when the tokens match."""
        import asyncio

        csrf = CSRFMiddleware(app)
        expected = "test_token"
        fake_request = Mock()
        fake_request.headers = {"X-CSRF-Token": expected}
        # Pin the token the middleware considers valid for this request.
        csrf._get_or_generate_token = Mock(return_value=expected)
        # The validator is a coroutine; drive it to completion. Any
        # exception raised inside it fails the test.
        asyncio.run(csrf._validate_csrf_token(fake_request, expected))

    def test_csrf_token_validation_failure(self):
        """Validation raises when the header token differs from the expected one."""
        import asyncio

        csrf = CSRFMiddleware(app)
        fake_request = Mock()
        fake_request.headers = {"X-CSRF-Token": "wrong_token"}
        with pytest.raises(Exception):  # HTTPException in real implementation
            asyncio.run(csrf._validate_csrf_token(fake_request, "correct_token"))
class TestAuthenticationEndpoints:
    """Test the authentication API endpoints through the shared ``client``."""

    def test_get_current_user_unauthorized(self):
        """``GET /auth/me`` without credentials is rejected with 401."""
        response = client.get("/auth/me")
        assert response.status_code == 401

    def test_get_current_authorized(self, auth_headers):
        """``GET /auth/me`` with a bearer token returns the user's e-mail."""
        # ``name`` is reserved by the ``Mock`` constructor (it labels the
        # mock itself), so ``Mock(name="Test User")`` would leave ``.name``
        # as an auto-created child mock. Assign the attribute afterwards.
        fake_user = Mock(id=1, email="test@example.com")
        fake_user.name = "Test User"
        # NOTE(review): if the route obtains the user via
        # ``Depends(get_current_active_user)``, patching the module
        # attribute may not replace the dependency; FastAPI's supported
        # mechanism is ``app.dependency_overrides`` — confirm.
        with patch('routes.auth.get_current_active_user') as mock_user:
            mock_user.return_value = fake_user
            response = client.get("/auth/me", headers=auth_headers)
        assert response.status_code == 200
        data = response.json()
        assert data["email"] == "test@example.com"

    def test_logout_success(self, auth_headers):
        """``POST /auth/logout`` with a bearer token returns 200 and a message."""
        with patch('routes.auth.get_current_active_user') as mock_user:
            mock_user.return_value = Mock(id=1, email="test@example.com")
            response = client.post("/auth/logout", headers=auth_headers)
        assert response.status_code == 200
        data = response.json()
        assert "message" in data

    def test_google_oauth_redirect(self):
        """``GET /auth/login/google`` answers with a redirect or a 200."""
        response = client.get("/auth/login/google")
        assert response.status_code in [302, 200]  # Redirect or mock response
class TestRateLimiting:
    """Smoke tests around rate limiting."""

    def test_rate_limit_headers(self, auth_headers):
        """Issue an authenticated ``GET /auth/me``; only checks it does not raise.

        No header assertions are made yet: which rate-limit headers exist
        depends on the actual rate limiting implementation.
        """
        with patch('routes.auth.get_current_active_user') as mock_user:
            mock_user.return_value = Mock(id=1, email="test@example.com")
            client.get("/auth/me", headers=auth_headers)

    def test_rate_limit_exceeded(self):
        """Send up to 50 rapid requests and inspect the first 429, if any.

        When a 429 arrives, its body must mention "rate limit". Never
        receiving a 429 within 50 requests is accepted as well.
        """
        for _attempt in range(50):  # Assuming rate limit is less than 50 requests
            reply = client.get("/auth/login/google")
            if reply.status_code != 429:
                continue
            assert "rate limit" in reply.text.lower()
            break
class TestSessionManagement:
    """Exercise session creation and invalidation helpers from ``auth.auth``."""

    def test_create_user_session(self, db_session, test_user):
        """A created session is stored for the user with the returned token."""
        from auth.auth import create_user_session

        issued = create_user_session(db_session, test_user)
        assert issued is not None

        # Look the session up again through the ORM.
        stored = (
            db_session.query(Session)
            .filter(Session.user_id == test_user.id)
            .first()
        )
        assert stored is not None
        assert stored.token == issued

    def test_invalidate_user_sessions(self, db_session, test_user):
        """Invalidating leaves no session rows for the user."""
        from auth.auth import create_user_session, invalidate_user_sessions

        # Two sessions, so that "all sessions" is actually exercised.
        for _ in range(2):
            create_user_session(db_session, test_user)

        invalidate_user_sessions(db_session, test_user)

        remaining = (
            db_session.query(Session)
            .filter(Session.user_id == test_user.id)
            .all()
        )
        assert len(remaining) == 0
class TestAnonymousAccess:
    """Test anonymous user access and limits enforced by ``AuthMiddleware``."""

    def test_anonymous_session_creation(self):
        """Handling a request without a session id marks it as anonymous.

        After the handler ran, ``request.state`` must carry a
        ``session_id`` attribute and ``anonymous`` set to ``True``.
        """
        import asyncio

        middleware = AuthMiddleware(app)
        # Mock request without session ID
        request = Mock()
        request.headers = {}
        # ``spec=[]`` stops the mock from auto-creating attributes on
        # access, so ``hasattr`` below is only true if the middleware
        # really assigned ``session_id``. With a plain ``Mock()`` that
        # check can never fail.
        request.state = Mock(spec=[])
        asyncio.run(middleware._handle_anonymous_request(request))
        # Should have session_id in state
        assert hasattr(request.state, 'session_id')
        assert request.state.anonymous is True

    def test_anonymous_message_limit(self):
        """A session already at the anonymous limit is refused."""
        import asyncio

        middleware = AuthMiddleware(app, anonymous_limit=2)
        # Seed a session whose message count equals the configured limit.
        session_id = "test_session"
        middleware._anonymous_sessions[session_id] = {
            "message_count": 2,
            "created_at": datetime.utcnow(),
            "last_activity": datetime.utcnow()
        }
        # Mock request that identifies itself with the seeded session.
        request = Mock()
        request.headers = {"X-Anonymous-Session-ID": session_id}
        request.state = Mock()
        # Should raise exception for exceeding limit
        with pytest.raises(Exception):  # HTTPException in real implementation
            asyncio.run(middleware._handle_anonymous_request(request))
class TestUserPreferences:
    """Exercise the ``/auth/preferences`` endpoints."""

    def test_get_user_preferences_not_found(self, auth_headers):
        """Fetching preferences that do not exist yet returns the defaults."""
        with patch('routes.auth.get_current_active_user') as mock_user:
            mock_user.return_value = Mock(id=999, email="test@example.com")
            response = client.get("/auth/preferences", headers=auth_headers)
        # Should create default preferences
        assert response.status_code == 200
        body = response.json()
        assert "theme" in body
        assert body["theme"] == "light"

    def test_update_user_preferences(self, auth_headers):
        """Updated preference values are echoed back in the response."""
        desired = {
            "theme": "dark",
            "language": "en",
            "notifications_enabled": False,
            "chat_settings": {"model": "gpt-4"},
        }
        with patch('routes.auth.get_current_active_user') as mock_user:
            mock_user.return_value = Mock(id=1, email="test@example.com")
            response = client.put(
                "/auth/preferences",
                json=desired,
                headers=auth_headers,
            )
        assert response.status_code == 200
        body = response.json()
        assert body["theme"] == "dark"
        assert body["notifications_enabled"] is False
class TestSecurityHeaders:
    """Test security-related headers."""

    def test_cors_headers(self):
        """A CORS preflight for ``/auth/me`` is answered with an allow-origin header."""
        # A CORS middleware only reacts to requests that look like CORS
        # requests. Starlette's implementation adds no CORS headers unless
        # ``Origin`` is present, and treats OPTIONS as a preflight only when
        # ``Access-Control-Request-Method`` is sent as well.
        # NOTE(review): the origin below must be permitted by the app's CORS
        # configuration — confirm against the allowed origins.
        preflight_headers = {
            "Origin": "http://localhost:3000",
            "Access-Control-Request-Method": "GET",
        }
        response = client.options("/auth/me", headers=preflight_headers)
        assert "access-control-allow-origin" in response.headers

    def test_csrf_cookie_set(self):
        """Issue a first request; only checks that it does not raise.

        No cookie assertions are made yet: how the CSRF middleware sets its
        cookie depends on actual implementation details.
        """
        client.get("/auth/me")
class TestErrorHandling:
    """Check how the auth endpoints respond to bad or missing credentials."""

    def test_invalid_token_format(self):
        """A malformed bearer token is answered with 401."""
        bad_credentials = {"Authorization": "Bearer invalid.token.format"}
        reply = client.get("/auth/me", headers=bad_credentials)
        assert reply.status_code == 401

    def test_missing_authorization_header(self):
        """A request with no ``Authorization`` header is answered with 401."""
        assert client.get("/auth/me").status_code == 401

    def test_sql_injection_attempts(self):
        """An SQL-injection-style payload is refused with 401 or 422."""
        payload = {"email": "'; DROP TABLE users; --"}
        reply = client.post("/auth/logout", json=payload)
        # Should handle gracefully without executing SQL
        assert reply.status_code in [401, 422]
# Integration Tests
class TestAuthenticationFlow:
    """Integration-style tests covering the authentication flow."""

    def test_full_oauth_flow_simulation(self):
        """Start the Google login with the OAuth collaborators patched out."""
        # All four collaborators are patched in one statement; they are
        # entered in the order listed.
        with patch('routes.auth.oauth.google.authorize_redirect') as mock_auth, \
                patch('routes.auth.get_or_create_user') as mock_user, \
                patch('routes.auth.create_user_session') as mock_session, \
                patch('routes.auth.create_access_token') as mock_token:
            # Setup mocks
            mock_auth.return_value = Mock()
            mock_user.return_value = Mock(id=1, email="test@example.com")
            mock_session.return_value = "session_token"
            mock_token.return_value = "jwt_token"
            # Test OAuth initiation
            reply = client.get("/auth/login/google")
            assert reply.status_code in [200, 302]

    def test_session_expiry(self, db_session, test_user):
        """A session whose expiry lies a day in the past is reported invalid."""
        from auth.auth import create_user_session, check_session_validity

        # Write the expired session straight to the database rather than
        # going through the session-creation helper.
        yesterday = datetime.utcnow() - timedelta(days=1)
        stale = Session(
            user_id=test_user.id,
            token="expired_token",
            expires_at=yesterday,
        )
        db_session.add(stale)
        db_session.commit()

        assert check_session_validity("expired_token", db_session) is False
if __name__ == "__main__":
    # Run this file's tests when executed directly. ``pytest.main`` returns
    # the run's exit code; propagate it so the process exits non-zero when
    # tests fail instead of always reporting success.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))