# champ-chatbot/tests/api/test_chat_post.py
# Uploaded by qyle in commit 8b9e569 ("deployment", verified).
import pytest
from fastapi.testclient import TestClient
from unittest.mock import Mock, patch
from main import app
# Shared synchronous test client for the FastAPI app imported from ``main``.
# Used by every test in TestChatEndpoint except the rate-limit test, which
# builds its own client inside the test body.
client = TestClient(app)
class TestChatEndpoint:
    """Tests for the POST /chat endpoint.

    Every test in this class requests the ``mock_dependencies`` fixture, which
    patches ``session_tracker``, ``PIIFilter``, ``session_conversation_store``,
    ``session_document_store``, ``call_llm`` and ``log_event`` on the ``main``
    module, so assertions are made against those mocks.
    """

    @pytest.fixture
    def base_required_fields(self):
        """Return the base fields required by IdentifierBase and ProfileBase."""
        return {
            "user_id": "test-user-123",
            "participant_id": "participant-456",
            "session_id": "test-session-123",
            "consent": True,
            "age_group": "25-34",
            "gender": "M",
            "roles": ["patient"],
        }

    @pytest.fixture
    def valid_payload(self, base_required_fields):
        """Return a complete, valid /chat request body using the "champ" model."""
        return {
            **base_required_fields,
            "conversation_id": "conversation-abc",
            "model_type": "champ",
            "lang": "en",
            "human_message": "What should I do about a fever?",
        }

    @pytest.fixture
    def mock_dependencies(self):
        """Patch the external collaborators of ``main`` for one test.

        Yields a dict of mocks keyed by short name. Defaults:
        - ``pii_filter`` is the *instance* returned by ``PIIFilter()``; its
          ``sanitize`` returns "sanitized message".
        - ``conv_store.add_human_message`` returns a one-message conversation.
        - ``doc_store.get_document_contents`` returns None (no documents).
        - ``call_llm`` returns a non-streaming ``(reply, {}, [])`` tuple whose
          reply is "AI response".
        """
        with (
            patch("main.session_tracker") as mock_tracker,
            patch("main.PIIFilter") as mock_pii_class,
            patch("main.session_conversation_store") as mock_conv_store,
            patch("main.session_document_store") as mock_doc_store,
            patch("main.call_llm") as mock_call_llm,
            patch("main.log_event") as mock_log_event,
        ):
            # PIIFilter() yields this instance, so tests assert on its sanitize().
            mock_pii = Mock()
            mock_pii.sanitize.return_value = "sanitized message"
            mock_pii_class.return_value = mock_pii
            # Conversation store hands back the history passed on to call_llm.
            mock_conv_store.add_human_message.return_value = [
                Mock(role="user", content="sanitized message")
            ]
            # No uploaded documents by default.
            mock_doc_store.get_document_contents.return_value = None
            # Non-streaming result by default; streaming tests override this.
            mock_call_llm.return_value = ("AI response", {}, [])
            yield {
                "tracker": mock_tracker,
                "pii_filter": mock_pii,
                "conv_store": mock_conv_store,
                "doc_store": mock_doc_store,
                "call_llm": mock_call_llm,
                "log_event": mock_log_event,
            }

    # ==================== Successful Chat Tests ====================

    def test_chat_success_non_streaming(self, valid_payload, mock_dependencies):
        """Test successful non-streaming chat response"""
        response = client.post("/chat", json=valid_payload)
        assert response.status_code == 200
        assert response.json() == {"reply": "AI response"}

    def test_chat_updates_session_tracker(self, valid_payload, mock_dependencies):
        """Test that session tracker is updated"""
        client.post("/chat", json=valid_payload)
        mock_dependencies["tracker"].update_session.assert_called_once_with(
            "test-session-123"
        )

    def test_chat_sanitizes_message(self, valid_payload, mock_dependencies):
        """Test that PII filter is applied to message"""
        client.post("/chat", json=valid_payload)
        mock_dependencies["pii_filter"].sanitize.assert_called_once_with(
            "What should I do about a fever?"
        )

    def test_chat_adds_human_message_to_store(self, valid_payload, mock_dependencies):
        """Test that sanitized message is added to conversation store"""
        client.post("/chat", json=valid_payload)
        mock_dependencies["conv_store"].add_human_message.assert_called_once_with(
            "test-session-123", "conversation-abc", "sanitized message"
        )

    def test_chat_retrieves_documents(self, valid_payload, mock_dependencies):
        """Test that documents are retrieved from document store"""
        client.post("/chat", json=valid_payload)
        mock_dependencies["doc_store"].get_document_contents.assert_called_once_with(
            "test-session-123"
        )

    def test_chat_calls_llm_with_correct_params(self, valid_payload, mock_dependencies):
        """Test that call_llm is invoked with correct parameters"""
        mock_conversation = [Mock()]
        mock_dependencies[
            "conv_store"
        ].add_human_message.return_value = mock_conversation
        mock_dependencies["doc_store"].get_document_contents.return_value = ["doc1"]
        client.post("/chat", json=valid_payload)
        # main reportedly runs call_llm through run_in_executor; TestClient only
        # returns once the response is complete, so the call has already
        # happened by the time we assert here.
        mock_dependencies["call_llm"].assert_called_once_with(
            "champ", "en", mock_conversation, ["doc1"]
        )

    def test_chat_adds_assistant_reply_to_store(self, valid_payload, mock_dependencies):
        """Test that assistant reply is added to conversation store"""
        client.post("/chat", json=valid_payload)
        mock_dependencies["conv_store"].add_assistant_reply.assert_called_once_with(
            "test-session-123", "conversation-abc", "AI response"
        )

    # ==================== Streaming Response Tests ====================

    def test_chat_streaming_response(self, valid_payload, mock_dependencies):
        """Test that an async-generator result from call_llm is streamed back"""

        async def mock_stream():
            yield "Hello "
            yield "world"

        mock_dependencies["call_llm"].return_value = mock_stream()
        response = client.post("/chat", json=valid_payload)
        assert response.status_code == 200
        # StreamingResponse chunks arrive concatenated in the body.
        content = response.content.decode()
        assert "Hello world" in content

    # ==================== Different Model Types Tests ====================

    def test_chat_openai_model(self, base_required_fields, mock_dependencies):
        """Test chat with OpenAI model"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "openai",
            "lang": "en",
            "human_message": "Hello",
        }

        # OpenAI returns AsyncGenerator
        async def mock_stream():
            yield "response"

        mock_dependencies["call_llm"].return_value = mock_stream()
        response = client.post("/chat", json=payload)
        assert response.status_code == 200

    def test_chat_google_conservative_model(
        self, base_required_fields, mock_dependencies
    ):
        """Test chat with Google conservative model"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "google-conservative",
            "lang": "en",
            "human_message": "Hello",
        }
        mock_dependencies["call_llm"].return_value = ("Response", {}, [])
        response = client.post("/chat", json=payload)
        assert response.status_code == 200
        assert response.json() == {"reply": "Response"}

    def test_chat_google_creative_model(self, base_required_fields, mock_dependencies):
        """Test chat with Google creative model"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "google-creative",
            "lang": "fr",
            "human_message": "Bonjour",
        }
        mock_dependencies["call_llm"].return_value = ("Réponse", {}, [])
        response = client.post("/chat", json=payload)
        assert response.status_code == 200
        assert response.json() == {"reply": "Réponse"}

    # ==================== Language Tests ====================

    def test_chat_french_language(self, base_required_fields, mock_dependencies):
        """Test chat with French language"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "fr",
            "human_message": "Comment allez-vous?",
        }
        response = client.post("/chat", json=payload)
        assert response.status_code == 200

    # ==================== Request Validation Tests ====================

    def test_chat_missing_human_message(self, base_required_fields, mock_dependencies):
        """Test that missing human_message returns 422"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
        }
        response = client.post("/chat", json=payload)
        assert response.status_code == 422

    def test_chat_empty_human_message(self, base_required_fields, mock_dependencies):
        """Test that empty human_message is rejected"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            "human_message": "",
        }
        response = client.post("/chat", json=payload)
        assert response.status_code == 422

    def test_chat_invalid_model_type(self, base_required_fields, mock_dependencies):
        """Test that invalid model_type is rejected"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "invalid-model",
            "lang": "en",
            "human_message": "Hello",
        }
        response = client.post("/chat", json=payload)
        assert response.status_code == 422

    def test_chat_invalid_language(self, base_required_fields, mock_dependencies):
        """Test that invalid language is rejected"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "es",  # Not in Literal["en", "fr"]
            "human_message": "Hello",
        }
        response = client.post("/chat", json=payload)
        assert response.status_code == 422

    def test_chat_message_too_long(self, base_required_fields, mock_dependencies):
        """Test that message exceeding MAX_MESSAGE_LENGTH is rejected"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            # Assumed to exceed MAX_MESSAGE_LENGTH — TODO confirm the limit.
            "human_message": "x" * 100000,
        }
        response = client.post("/chat", json=payload)
        assert response.status_code == 422

    def test_chat_sanitizes_html_in_message(
        self, base_required_fields, mock_dependencies
    ):
        """Test that HTML tags are removed from human_message"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            "human_message": "<script>alert('xss')</script>Hello",
        }
        response = client.post("/chat", json=payload)
        # Should succeed with sanitized message
        assert response.status_code == 200
        # The PII filter receives the validated message positionally; the
        # script tag must already be gone by then.
        filtered_input = mock_dependencies["pii_filter"].sanitize.call_args.args[0]
        assert "<script>" not in filtered_input

    def test_chat_invalid_conversation_id(
        self, base_required_fields, mock_dependencies
    ):
        """Test that invalid conversation_id is rejected"""
        payload = {
            **base_required_fields,
            "conversation_id": "invalid@id!",
            "model_type": "champ",
            "lang": "en",
            "human_message": "Hello",
        }
        response = client.post("/chat", json=payload)
        assert response.status_code == 422

    # ==================== Rate Limiting Tests ====================

    @pytest.mark.enable_rate_limit
    def test_chat_rate_limiting(self, valid_payload, mock_dependencies):
        """Test that rate limiting works (20 requests per minute)"""
        # Re-imported inside the test rather than reusing the module-level
        # client — presumably so an app rebuilt by the enable_rate_limit setup
        # is picked up. TODO confirm against conftest.
        from fastapi.testclient import TestClient
        from main import app

        rate_limit_client = TestClient(app)
        # Make 21 rapid requests: one more than the per-minute budget.
        responses = [
            rate_limit_client.post("/chat", json=valid_payload) for _ in range(21)
        ]
        # 21st should be rate limited
        assert responses[-1].status_code == 429

    # ==================== Integration Tests ====================

    def test_chat_full_workflow(self, valid_payload, mock_dependencies):
        """Test complete chat workflow"""
        mock_conversation = [Mock(role="user", content="sanitized message")]
        mock_dependencies[
            "conv_store"
        ].add_human_message.return_value = mock_conversation
        mock_dependencies["doc_store"].get_document_contents.return_value = ["doc1"]
        mock_dependencies["call_llm"].return_value = (
            "Full response",
            {"key": "value"},
            ["ctx"],
        )
        response = client.post("/chat", json=valid_payload)
        assert response.status_code == 200
        assert response.json() == {"reply": "Full response"}
        # Verify every workflow step ran exactly once (call order is not checked).
        mock_dependencies["tracker"].update_session.assert_called_once()
        mock_dependencies["pii_filter"].sanitize.assert_called_once()
        mock_dependencies["conv_store"].add_human_message.assert_called_once()
        mock_dependencies["doc_store"].get_document_contents.assert_called_once()
        mock_dependencies["call_llm"].assert_called_once()
        mock_dependencies["conv_store"].add_assistant_reply.assert_called_once()

    def test_chat_with_documents(self, valid_payload, mock_dependencies):
        """Test that documents uploaded by the user are passed to call_llm"""
        documents = [
            "Document content 1",
            "Document content 2",
        ]
        mock_dependencies["doc_store"].get_document_contents.return_value = documents
        response = client.post("/chat", json=valid_payload)
        assert response.status_code == 200
        # call_llm is called positionally as
        # (model_type, lang, conversation, documents).
        mock_dependencies["call_llm"].assert_called_once()
        assert mock_dependencies["call_llm"].call_args.args[3] == documents

    def test_chat_multiple_messages_same_conversation(
        self, base_required_fields, mock_dependencies
    ):
        """Test multiple messages in same conversation"""
        payload1 = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            "human_message": "First message",
        }
        payload2 = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            "human_message": "Second message",
        }
        response1 = client.post("/chat", json=payload1)
        response2 = client.post("/chat", json=payload2)
        assert response1.status_code == 200
        assert response2.status_code == 200
        # add_human_message(session_id, conversation_id, message): both turns
        # must be stored under the same conversation.
        human_calls = mock_dependencies["conv_store"].add_human_message.call_args_list
        assert [c.args[1] for c in human_calls] == ["conv-1", "conv-1"]

    def test_chat_different_conversations_same_session(
        self, base_required_fields, mock_dependencies
    ):
        """Test different conversations in same session"""
        payload1 = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            "human_message": "Message in conv 1",
        }
        payload2 = {
            **base_required_fields,
            "conversation_id": "conv-2",
            "model_type": "champ",
            "lang": "en",
            "human_message": "Message in conv 2",
        }
        response1 = client.post("/chat", json=payload1)
        response2 = client.post("/chat", json=payload2)
        assert response1.status_code == 200
        assert response2.status_code == 200
        # add_human_message(session_id, conversation_id, message): one session,
        # two distinct conversations.
        human_calls = mock_dependencies["conv_store"].add_human_message.call_args_list
        assert [c.args[0] for c in human_calls] == ["test-session-123"] * 2
        assert [c.args[1] for c in human_calls] == ["conv-1", "conv-2"]

    # ==================== Edge Cases ====================

    def test_chat_special_characters_in_message(
        self, base_required_fields, mock_dependencies
    ):
        """Test message with special characters"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            "human_message": "Hello! 你好 🎉 @#$%",
        }
        response = client.post("/chat", json=payload)
        assert response.status_code == 200

    def test_chat_multiline_message(self, base_required_fields, mock_dependencies):
        """Test message with newlines"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            "human_message": "Line 1\nLine 2\nLine 3",
        }
        response = client.post("/chat", json=payload)
        assert response.status_code == 200

    def test_chat_empty_reply_from_llm(self, valid_payload, mock_dependencies):
        """Test handling of empty reply from LLM"""
        mock_dependencies["call_llm"].return_value = ("", {}, [])
        response = client.post("/chat", json=valid_payload)
        assert response.status_code == 200
        assert response.json() == {"reply": ""}