# NOTE(review): "Spaces: Paused / Paused" below is residue from the page export
# this file was recovered from; it is not part of the test module.
# Spaces: Paused / Paused
import pytest
from fastapi.testclient import TestClient
from unittest.mock import Mock, patch

from main import app

# Shared module-level client; individual tests may create their own
# (e.g. the rate-limiting test) when isolation is needed.
client = TestClient(app)
class TestChatEndpoint:
    """Test the POST /chat endpoint."""

    # ==================== Fixtures ====================
    # These were plain methods before; without @pytest.fixture the tests
    # fail at collection time with "fixture '...' not found".

    @pytest.fixture
    def base_required_fields(self):
        """Base fields required by IdentifierBase and ProfileBase."""
        return {
            "user_id": "test-user-123",
            "participant_id": "participant-456",
            "session_id": "test-session-123",
            "consent": True,
            "age_group": "25-34",
            "gender": "M",
            "roles": ["patient"],
        }

    @pytest.fixture
    def valid_payload(self, base_required_fields):
        """A complete, valid request body for POST /chat."""
        return {
            **base_required_fields,
            "conversation_id": "conversation-abc",
            "model_type": "champ",
            "lang": "en",
            "human_message": "What should I do about a fever?",
        }

    @pytest.fixture
    def mock_dependencies(self):
        """Mock all external dependencies of the /chat endpoint."""
        with (
            patch("main.session_tracker") as mock_tracker,
            patch("main.PIIFilter") as mock_pii_class,
            patch("main.session_conversation_store") as mock_conv_store,
            patch("main.session_document_store") as mock_doc_store,
            patch("main.call_llm") as mock_call_llm,
            patch("main.log_event") as mock_log_event,
        ):
            # Setup PIIFilter: the endpoint instantiates the class, so the
            # class mock must return our instance mock.
            mock_pii = Mock()
            mock_pii.sanitize.return_value = "sanitized message"
            mock_pii_class.return_value = mock_pii
            # Setup conversation store
            mock_conv_store.add_human_message.return_value = [
                Mock(role="user", content="sanitized message")
            ]
            # Setup document store (no uploaded documents by default)
            mock_doc_store.get_document_contents.return_value = None
            # Setup call_llm (non-streaming tuple response by default)
            mock_call_llm.return_value = ("AI response", {}, [])
            yield {
                "tracker": mock_tracker,
                "pii_filter": mock_pii,
                "conv_store": mock_conv_store,
                "doc_store": mock_doc_store,
                "call_llm": mock_call_llm,
                "log_event": mock_log_event,
            }

    # ==================== Successful Chat Tests ====================

    def test_chat_success_non_streaming(self, valid_payload, mock_dependencies):
        """Test successful non-streaming chat response"""
        response = client.post("/chat", json=valid_payload)
        assert response.status_code == 200
        assert response.json() == {"reply": "AI response"}

    def test_chat_updates_session_tracker(self, valid_payload, mock_dependencies):
        """Test that session tracker is updated"""
        client.post("/chat", json=valid_payload)
        mock_dependencies["tracker"].update_session.assert_called_once_with(
            "test-session-123"
        )

    def test_chat_sanitizes_message(self, valid_payload, mock_dependencies):
        """Test that PII filter is applied to message"""
        client.post("/chat", json=valid_payload)
        mock_dependencies["pii_filter"].sanitize.assert_called_once_with(
            "What should I do about a fever?"
        )

    def test_chat_adds_human_message_to_store(self, valid_payload, mock_dependencies):
        """Test that sanitized message is added to conversation store"""
        client.post("/chat", json=valid_payload)
        mock_dependencies["conv_store"].add_human_message.assert_called_once_with(
            "test-session-123", "conversation-abc", "sanitized message"
        )

    def test_chat_retrieves_documents(self, valid_payload, mock_dependencies):
        """Test that documents are retrieved from document store"""
        client.post("/chat", json=valid_payload)
        mock_dependencies["doc_store"].get_document_contents.assert_called_once_with(
            "test-session-123"
        )

    def test_chat_calls_llm_with_correct_params(self, valid_payload, mock_dependencies):
        """Test that call_llm is invoked with correct parameters"""
        mock_conversation = [Mock()]
        mock_dependencies[
            "conv_store"
        ].add_human_message.return_value = mock_conversation
        mock_dependencies["doc_store"].get_document_contents.return_value = ["doc1"]
        client.post("/chat", json=valid_payload)
        # call_llm is wrapped in run_in_executor, so we need to wait
        # The test client handles this synchronously
        mock_dependencies["call_llm"].assert_called_once_with(
            "champ", "en", mock_conversation, ["doc1"]
        )

    def test_chat_adds_assistant_reply_to_store(self, valid_payload, mock_dependencies):
        """Test that assistant reply is added to conversation store"""
        client.post("/chat", json=valid_payload)
        mock_dependencies["conv_store"].add_assistant_reply.assert_called_once_with(
            "test-session-123", "conversation-abc", "AI response"
        )

    # ==================== Streaming Response Tests ====================

    def test_chat_streaming_response(self, valid_payload, mock_dependencies):
        """Test streaming response from OpenAI"""

        async def mock_stream():
            yield "Hello "
            yield "world"

        mock_dependencies["call_llm"].return_value = mock_stream()
        response = client.post("/chat", json=valid_payload)
        assert response.status_code == 200
        # StreamingResponse returns chunks
        content = response.content.decode()
        assert "Hello world" in content

    # ==================== Different Model Types Tests ====================

    def test_chat_openai_model(self, base_required_fields, mock_dependencies):
        """Test chat with OpenAI model"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "openai",
            "lang": "en",
            "human_message": "Hello",
        }

        # OpenAI returns AsyncGenerator
        async def mock_stream():
            yield "response"

        mock_dependencies["call_llm"].return_value = mock_stream()
        response = client.post("/chat", json=payload)
        assert response.status_code == 200

    def test_chat_google_conservative_model(
        self, base_required_fields, mock_dependencies
    ):
        """Test chat with Google conservative model"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "google-conservative",
            "lang": "en",
            "human_message": "Hello",
        }
        mock_dependencies["call_llm"].return_value = ("Response", {}, [])
        response = client.post("/chat", json=payload)
        assert response.status_code == 200
        assert response.json() == {"reply": "Response"}

    def test_chat_google_creative_model(self, base_required_fields, mock_dependencies):
        """Test chat with Google creative model"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "google-creative",
            "lang": "fr",
            "human_message": "Bonjour",
        }
        mock_dependencies["call_llm"].return_value = ("Réponse", {}, [])
        response = client.post("/chat", json=payload)
        assert response.status_code == 200
        assert response.json() == {"reply": "Réponse"}

    # ==================== Language Tests ====================

    def test_chat_french_language(self, base_required_fields, mock_dependencies):
        """Test chat with French language"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "fr",
            "human_message": "Comment allez-vous?",
        }
        response = client.post("/chat", json=payload)
        assert response.status_code == 200

    # ==================== Request Validation Tests ====================

    def test_chat_missing_human_message(self, base_required_fields, mock_dependencies):
        """Test that missing human_message returns 422"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
        }
        response = client.post("/chat", json=payload)
        assert response.status_code == 422

    def test_chat_empty_human_message(self, base_required_fields, mock_dependencies):
        """Test that empty human_message is rejected"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            "human_message": "",
        }
        response = client.post("/chat", json=payload)
        assert response.status_code == 422

    def test_chat_invalid_model_type(self, base_required_fields, mock_dependencies):
        """Test that invalid model_type is rejected"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "invalid-model",
            "lang": "en",
            "human_message": "Hello",
        }
        response = client.post("/chat", json=payload)
        assert response.status_code == 422

    def test_chat_invalid_language(self, base_required_fields, mock_dependencies):
        """Test that invalid language is rejected"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "es",  # Not in Literal["en", "fr"]
            "human_message": "Hello",
        }
        response = client.post("/chat", json=payload)
        assert response.status_code == 422

    def test_chat_message_too_long(self, base_required_fields, mock_dependencies):
        """Test that message exceeding MAX_MESSAGE_LENGTH is rejected"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            "human_message": "x" * 100000,  # Assuming this exceeds limit
        }
        response = client.post("/chat", json=payload)
        assert response.status_code == 422

    def test_chat_sanitizes_html_in_message(
        self, base_required_fields, mock_dependencies
    ):
        """Test that HTML tags are removed from human_message"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            "human_message": "<script>alert('xss')</script>Hello",
        }
        response = client.post("/chat", json=payload)
        # Should succeed with sanitized message
        assert response.status_code == 200

    def test_chat_invalid_conversation_id(
        self, base_required_fields, mock_dependencies
    ):
        """Test that invalid conversation_id is rejected"""
        payload = {
            **base_required_fields,
            "conversation_id": "invalid@id!",
            "model_type": "champ",
            "lang": "en",
            "human_message": "Hello",
        }
        response = client.post("/chat", json=payload)
        assert response.status_code == 422

    # ==================== Rate Limiting Tests ====================

    def test_chat_rate_limiting(self, valid_payload, mock_dependencies):
        """Test that rate limiting works (20 requests per minute)"""
        # Fresh client instance; NOTE(review): most limiters key on client IP
        # rather than client instance — confirm this actually isolates counts.
        rate_limit_client = TestClient(app)
        # Make 21 rapid requests
        responses = []
        for i in range(21):
            response = rate_limit_client.post("/chat", json=valid_payload)
            responses.append(response)
        # 21st should be rate limited
        assert responses[-1].status_code == 429

    # ==================== Integration Tests ====================

    def test_chat_full_workflow(self, valid_payload, mock_dependencies):
        """Test complete chat workflow"""
        mock_conversation = [Mock(role="user", content="sanitized message")]
        mock_dependencies[
            "conv_store"
        ].add_human_message.return_value = mock_conversation
        mock_dependencies["doc_store"].get_document_contents.return_value = ["doc1"]
        mock_dependencies["call_llm"].return_value = (
            "Full response",
            {"key": "value"},
            ["ctx"],
        )
        response = client.post("/chat", json=valid_payload)
        assert response.status_code == 200
        assert response.json() == {"reply": "Full response"}
        # Verify workflow order
        mock_dependencies["tracker"].update_session.assert_called_once()
        mock_dependencies["pii_filter"].sanitize.assert_called_once()
        mock_dependencies["conv_store"].add_human_message.assert_called_once()
        mock_dependencies["doc_store"].get_document_contents.assert_called_once()
        mock_dependencies["call_llm"].assert_called_once()
        mock_dependencies["conv_store"].add_assistant_reply.assert_called_once()

    def test_chat_with_documents(self, valid_payload, mock_dependencies):
        """Test chat when user has uploaded documents"""
        docs = ["Document content 1", "Document content 2"]
        mock_dependencies["doc_store"].get_document_contents.return_value = docs
        response = client.post("/chat", json=valid_payload)
        assert response.status_code == 200
        # Documents should be passed through to call_llm as the last
        # positional argument (model_type, lang, conversation, documents).
        mock_dependencies["call_llm"].assert_called_once()
        call_args = mock_dependencies["call_llm"].call_args.args
        assert call_args[-1] == docs

    def test_chat_multiple_messages_same_conversation(
        self, base_required_fields, mock_dependencies
    ):
        """Test multiple messages in same conversation"""
        payload1 = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            "human_message": "First message",
        }
        payload2 = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            "human_message": "Second message",
        }
        response1 = client.post("/chat", json=payload1)
        response2 = client.post("/chat", json=payload2)
        assert response1.status_code == 200
        assert response2.status_code == 200

    def test_chat_different_conversations_same_session(
        self, base_required_fields, mock_dependencies
    ):
        """Test different conversations in same session"""
        payload1 = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            "human_message": "Message in conv 1",
        }
        payload2 = {
            **base_required_fields,
            "conversation_id": "conv-2",
            "model_type": "champ",
            "lang": "en",
            "human_message": "Message in conv 2",
        }
        response1 = client.post("/chat", json=payload1)
        response2 = client.post("/chat", json=payload2)
        assert response1.status_code == 200
        assert response2.status_code == 200

    # ==================== Edge Cases ====================

    def test_chat_special_characters_in_message(
        self, base_required_fields, mock_dependencies
    ):
        """Test message with special characters"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            "human_message": "Hello! 你好 🎉 @#$%",
        }
        response = client.post("/chat", json=payload)
        assert response.status_code == 200

    def test_chat_multiline_message(self, base_required_fields, mock_dependencies):
        """Test message with newlines"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            "human_message": "Line 1\nLine 2\nLine 3",
        }
        response = client.post("/chat", json=payload)
        assert response.status_code == 200

    def test_chat_empty_reply_from_llm(self, valid_payload, mock_dependencies):
        """Test handling of empty reply from LLM"""
        mock_dependencies["call_llm"].return_value = ("", {}, [])
        response = client.post("/chat", json=valid_payload)
        assert response.status_code == 200
        assert response.json() == {"reply": ""}