Spaces:
Sleeping
Sleeping
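# tests/test_memory_manager.py
"""
Unit tests for ``src.core.memory_manager.MemoryManager``.

The repository modules, the summariser / NVIDIA chat helpers and the embedding
client are replaced with mocks, so these tests exercise only MemoryManager's own
orchestration and error handling.

Run locally: `SKIP_API_TESTS=1 python -m tests.test_memory_manager`
(``SKIP_API_TESTS`` is parsed into a module-level flag, but no test in this
module currently consults it.)
"""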
| import os | |
| import unittest | |
| from datetime import datetime, timezone | |
| from unittest.mock import AsyncMock, MagicMock, patch | |
| from src.core.memory_manager import MemoryManager | |
| from src.data.connection import ActionFailed | |
| from src.models.account import Account | |
| from src.models.medical import MedicalMemory, SemanticSearchResult | |
| from src.models.patient import Patient | |
| from src.models.session import Message, Session | |
| from src.utils.embeddings import EmbeddingClient | |
| from src.utils.rotator import APIKeyRotator | |
# Opt-out switch for tests that would hit real external APIs: "true", "1" or
# "yes" (in any letter case) turns it on; an unset variable or any other value
# leaves it False. No test in this module reads it at the moment.
SKIP_API_TESTS = os.environ.get("SKIP_API_TESTS", "false").lower() in {"1", "true", "yes"}
# IsolatedAsyncioTestCase lets the `async def test_*` methods below be
# coroutines that await MemoryManager directly; plain `def` tests run as usual.
class TestMemoryManager(unittest.IsolatedAsyncioTestCase):
    """Unit tests for MemoryManager with every collaborator mocked.

    ``setUp`` patches the repository modules and the summariser / NVIDIA chat
    helpers as they are referenced from ``src.core.memory_manager``, and injects
    a mock embedding client, so each test only checks how MemoryManager drives
    those dependencies and handles their failures.
    """

    def _start_patch(self, target, **kwargs):
        """Start ``patch(target, **kwargs)`` and return the active mock.

        The matching ``stop`` is registered with ``addCleanup`` rather than
        left to ``tearDown``. Cleanups still run when ``setUp`` raises part-way
        through, but ``tearDown`` is skipped in that case. That skip would leave
        every patch started so far active for all subsequent tests.
        """
        patcher = patch(target, **kwargs)
        mocked = patcher.start()
        self.addCleanup(patcher.stop)
        return mocked

    def setUp(self):
        """Patch MemoryManager's dependencies and build a fresh instance."""
        # 1. Repository modules, replaced by default MagicMocks.
        self.mock_account_repo = self._start_patch('src.core.memory_manager.account_repo')
        self.mock_patient_repo = self._start_patch('src.core.memory_manager.patient_repo')
        self.mock_session_repo = self._start_patch('src.core.memory_manager.session_repo')
        self.mock_memory_repo = self._start_patch('src.core.memory_manager.memory_repo')

        # 2. Service helpers that MemoryManager awaits. They need AsyncMock,
        #    because a plain MagicMock's return value is not awaitable.
        self.mock_summarise_title = self._start_patch(
            'src.core.memory_manager.summariser.summarise_title_with_nvidia',
            new_callable=AsyncMock,
        )
        self.mock_summarise_gemini = self._start_patch(
            'src.core.memory_manager.summariser.summarise_qa_with_gemini',
            new_callable=AsyncMock,
        )
        self.mock_summarise_nvidia = self._start_patch(
            'src.core.memory_manager.summariser.summarise_qa_with_nvidia',
            new_callable=AsyncMock,
        )
        self.mock_nvidia_chat = self._start_patch(
            'src.core.memory_manager.nvidia_chat',
            new_callable=AsyncMock,
        )

        # 3. Injected dependencies. spec= makes typos in mocked method names
        #    fail loudly instead of silently creating new attributes.
        self.mock_embedder = MagicMock(spec=EmbeddingClient)
        self.mock_gemini_rotator = MagicMock(spec=APIKeyRotator)
        self.mock_nvidia_rotator = MagicMock(spec=APIKeyRotator)

        # 4. The class under test.
        self.manager = MemoryManager(embedder=self.mock_embedder, max_sessions_per_user=20)

        # 5. Common test data (ObjectId-shaped hex strings).
        self.user_id = "60c72b2f9b1d8b3b3c9d8b1a"
        self.patient_id = "60c72b2f9b1d8b3b3c9d8b1b"
        self.session_id = "60c72b2f9b1d8b3b3c9d8b1c"
        self.now = datetime.now(timezone.utc)

    # --- Account Management Tests ---

    def test_create_account_success(self):
        """create_account returns the repo's id and passes specialty=None."""
        self.mock_account_repo.create_account.return_value = self.user_id
        result = self.manager.create_account(name="Dr. Test", role="Doctor")
        self.assertEqual(result, self.user_id)
        self.mock_account_repo.create_account.assert_called_once_with(
            name="Dr. Test", role="Doctor", specialty=None
        )

    def test_create_account_failure(self):
        """ActionFailed from the repo makes create_account return None."""
        self.mock_account_repo.create_account.side_effect = ActionFailed("DB error")
        result = self.manager.create_account()
        self.assertIsNone(result)

    def test_get_account_success(self):
        """get_account returns the repo's account for the given id."""
        mock_account = MagicMock(spec=Account)
        self.mock_account_repo.get_account.return_value = mock_account
        result = self.manager.get_account(self.user_id)
        self.assertEqual(result, mock_account)
        self.mock_account_repo.get_account.assert_called_once_with(self.user_id)

    def test_get_account_failure(self):
        """ActionFailed from the repo makes get_account return None."""
        self.mock_account_repo.get_account.side_effect = ActionFailed("DB error")
        result = self.manager.get_account(self.user_id)
        self.assertIsNone(result)

    def test_get_all_accounts_success(self):
        """get_all_accounts forwards to the repo with limit=50."""
        self.mock_account_repo.get_all_accounts.return_value = [MagicMock(spec=Account)]
        result = self.manager.get_all_accounts()
        self.assertEqual(len(result), 1)
        self.mock_account_repo.get_all_accounts.assert_called_once_with(limit=50)

    def test_get_all_accounts_failure(self):
        """ActionFailed from the repo makes get_all_accounts return []."""
        self.mock_account_repo.get_all_accounts.side_effect = ActionFailed("DB error")
        result = self.manager.get_all_accounts()
        self.assertEqual(result, [])

    def test_search_accounts_success(self):
        """search_accounts forwards the query to the repo with limit=10."""
        self.mock_account_repo.search_accounts.return_value = [MagicMock(spec=Account)]
        result = self.manager.search_accounts("query")
        self.assertEqual(len(result), 1)
        self.mock_account_repo.search_accounts.assert_called_once_with("query", limit=10)

    def test_search_accounts_failure(self):
        """ActionFailed from the repo makes search_accounts return []."""
        self.mock_account_repo.search_accounts.side_effect = ActionFailed("DB error")
        result = self.manager.search_accounts("query")
        self.assertEqual(result, [])

    # --- Patient Management Tests ---

    def test_create_patient_success(self):
        """create_patient forwards its kwargs unchanged and returns the new id."""
        self.mock_patient_repo.create_patient.return_value = self.patient_id
        result = self.manager.create_patient(name="John Doe", age=40)
        self.assertEqual(result, self.patient_id)
        self.mock_patient_repo.create_patient.assert_called_once_with(name="John Doe", age=40)

    def test_create_patient_failure(self):
        """ActionFailed from the repo makes create_patient return None."""
        self.mock_patient_repo.create_patient.side_effect = ActionFailed("DB error")
        result = self.manager.create_patient(name="John Doe")
        self.assertIsNone(result)

    def test_get_patient_by_id_success(self):
        """get_patient_by_id returns the repo's patient for the given id."""
        mock_patient = MagicMock(spec=Patient)
        self.mock_patient_repo.get_patient_by_id.return_value = mock_patient
        result = self.manager.get_patient_by_id(self.patient_id)
        self.assertEqual(result, mock_patient)
        self.mock_patient_repo.get_patient_by_id.assert_called_once_with(self.patient_id)

    def test_get_patient_by_id_failure(self):
        """ActionFailed from the repo makes get_patient_by_id return None."""
        self.mock_patient_repo.get_patient_by_id.side_effect = ActionFailed("DB error")
        result = self.manager.get_patient_by_id(self.patient_id)
        self.assertIsNone(result)

    def test_update_patient_profile_success(self):
        """update_patient_profile forwards the updates and returns the repo result."""
        self.mock_patient_repo.update_patient_profile.return_value = 1
        updates = {"age": 41}
        result = self.manager.update_patient_profile(self.patient_id, updates)
        self.assertEqual(result, 1)
        self.mock_patient_repo.update_patient_profile.assert_called_once_with(
            self.patient_id, updates
        )

    def test_update_patient_profile_failure(self):
        """ActionFailed from the repo makes update_patient_profile return 0."""
        self.mock_patient_repo.update_patient_profile.side_effect = ActionFailed("DB error")
        result = self.manager.update_patient_profile(self.patient_id, {})
        self.assertEqual(result, 0)

    # --- Session Management Tests ---

    def test_create_session_success(self):
        """create_session uses the default title "New Chat" and returns the session."""
        mock_session = MagicMock(spec=Session)
        self.mock_session_repo.create_session.return_value = mock_session
        result = self.manager.create_session(self.user_id, self.patient_id)
        self.assertEqual(result, mock_session)
        self.mock_session_repo.create_session.assert_called_once_with(
            self.user_id, self.patient_id, "New Chat"
        )

    def test_create_session_failure(self):
        """ActionFailed from the repo makes create_session return None."""
        self.mock_session_repo.create_session.side_effect = ActionFailed("DB error")
        result = self.manager.create_session(self.user_id, self.patient_id)
        self.assertIsNone(result)

    def test_get_user_sessions_success(self):
        """get_user_sessions queries with limit=20 (the max_sessions_per_user set in setUp)."""
        self.mock_session_repo.get_user_sessions.return_value = [MagicMock(spec=Session)]
        result = self.manager.get_user_sessions(self.user_id)
        self.assertEqual(len(result), 1)
        self.mock_session_repo.get_user_sessions.assert_called_once_with(self.user_id, limit=20)

    def test_delete_session_success(self):
        """delete_session forwards the id and returns the repo's True."""
        self.mock_session_repo.delete_session.return_value = True
        result = self.manager.delete_session(self.session_id)
        self.assertTrue(result)
        self.mock_session_repo.delete_session.assert_called_once_with(self.session_id)

    def test_delete_session_failure(self):
        """ActionFailed from the repo makes delete_session return False."""
        self.mock_session_repo.delete_session.side_effect = ActionFailed("DB error")
        result = self.manager.delete_session(self.session_id)
        self.assertFalse(result)

    # --- Core Business Logic (Async) Tests ---

    async def test_process_medical_exchange_success(self):
        """Both messages are stored and the summary is embedded and saved as a memory."""
        question = "What are the side effects?"
        answer = "Common side effects include..."
        summary = "q: side effects a: common ones are..."
        embedding = [0.1, 0.2, 0.3]

        self.mock_summarise_gemini.return_value = summary
        # embed() takes a batch and returns one vector per input.
        self.mock_embedder.embed.return_value = [embedding]
        # Title generation is covered by its own tests; isolate it here.
        self.manager._update_session_title_if_first_message = AsyncMock()

        result = await self.manager.process_medical_exchange(
            self.session_id, self.patient_id, self.user_id, question, answer,
            self.mock_gemini_rotator, self.mock_nvidia_rotator,
        )

        self.assertEqual(result, summary)
        self.assertEqual(self.mock_session_repo.add_message.call_count, 2)
        self.mock_session_repo.add_message.assert_any_call(self.session_id, question, sent_by_user=True)
        self.mock_session_repo.add_message.assert_any_call(self.session_id, answer, sent_by_user=False)
        self.mock_summarise_gemini.assert_awaited_once()
        self.mock_embedder.embed.assert_called_once_with([summary])
        self.mock_memory_repo.create_memory.assert_called_once_with(
            patient_id=self.patient_id,
            doctor_id=self.user_id,
            session_id=self.session_id,
            summary=summary,
            embedding=embedding,
        )
        self.manager._update_session_title_if_first_message.assert_awaited_once()

    async def test_process_medical_exchange_db_failure(self):
        """ActionFailed while storing a message makes the exchange return None."""
        self.mock_session_repo.add_message.side_effect = ActionFailed("DB write failed")
        result = await self.manager.process_medical_exchange(
            self.session_id, self.patient_id, self.user_id, "q", "a",
            self.mock_gemini_rotator, self.mock_nvidia_rotator,
        )
        self.assertIsNone(result)

    async def test_process_medical_exchange_embedding_failure(self):
        """If embedding raises, the memory is still saved with embedding=None."""
        self.mock_embedder.embed.side_effect = Exception("Embedding model down")
        self.mock_summarise_gemini.return_value = "summary"
        self.manager._update_session_title_if_first_message = AsyncMock()

        await self.manager.process_medical_exchange(
            self.session_id, self.patient_id, self.user_id, "q", "a",
            self.mock_gemini_rotator, self.mock_nvidia_rotator,
        )

        # create_memory must still be called, just without an embedding.
        self.mock_memory_repo.create_memory.assert_called_once()
        kwargs = self.mock_memory_repo.create_memory.call_args.kwargs
        self.assertIsNone(kwargs.get("embedding"))

    async def test_get_enhanced_context_full(self):
        """STM, LTM and current-conversation sections all appear in the context."""
        question = "Is this medication safe?"

        mock_stm = [MagicMock(spec=MedicalMemory, summary="STM summary 1")]
        mock_ltm = [MagicMock(spec=SemanticSearchResult, summary="LTM summary 1")]
        mock_messages = [MagicMock(spec=Message, sent_by_user=True, content="Previous question")]
        mock_session = MagicMock(spec=Session, messages=mock_messages)

        self.mock_memory_repo.get_recent_memories.return_value = mock_stm
        # nvidia_chat presumably acts as the STM relevance filter (compare the
        # _filter_summaries_for_relevance tests): echo the summary so it is kept.
        self.mock_nvidia_chat.return_value = "STM summary 1"
        self.mock_embedder.embed.return_value = [[0.5]]
        self.mock_memory_repo.search_memories_semantic.return_value = mock_ltm
        self.mock_session_repo.get_session.return_value = mock_session

        context = await self.manager.get_enhanced_context(
            self.session_id, self.patient_id, question, self.mock_nvidia_rotator
        )

        self.assertIn("Recent relevant medical context (STM)", context)
        self.assertIn("STM summary 1", context)
        self.assertIn("Semantically relevant medical history (LTM)", context)
        self.assertIn("LTM summary 1", context)
        self.assertIn("Current conversation", context)
        self.assertIn("User: Previous question", context)
        self.mock_memory_repo.get_recent_memories.assert_called_once_with(self.patient_id, limit=3)
        self.mock_nvidia_chat.assert_awaited_once()
        self.mock_memory_repo.search_memories_semantic.assert_called_once()
        self.mock_session_repo.get_session.assert_called_once_with(self.session_id)

    async def test_get_enhanced_context_no_ltm(self):
        """Empty LTM results and an empty session omit those sections."""
        self.mock_memory_repo.get_recent_memories.return_value = [MagicMock(spec=MedicalMemory, summary="STM")]
        self.mock_nvidia_chat.return_value = "STM"
        self.mock_embedder.embed.return_value = [[0.5]]
        self.mock_memory_repo.search_memories_semantic.return_value = []  # no LTM hits
        self.mock_session_repo.get_session.return_value = MagicMock(spec=Session, messages=[])

        context = await self.manager.get_enhanced_context(
            self.session_id, self.patient_id, "question", self.mock_nvidia_rotator
        )

        self.assertIn("Recent relevant medical context (STM)", context)
        self.assertNotIn("Semantically relevant medical history (LTM)", context)
        self.assertNotIn("Current conversation", context)  # no messages

    # --- Private Helper (Async) Tests ---

    async def test_update_session_title_if_first_message_success(self):
        """With exactly two stored messages, a summarised title is written."""
        question = "My leg hurts, what should I do?"
        # len(messages) == 2, presumably the first question/answer pair.
        mock_session = MagicMock(spec=Session, messages=[1, 2])
        self.manager.get_session = MagicMock(return_value=mock_session)
        self.manager.update_session_title = MagicMock()
        self.mock_summarise_title.return_value = "Leg Pain Inquiry"

        await self.manager._update_session_title_if_first_message(
            self.session_id, question, self.mock_nvidia_rotator
        )

        self.manager.get_session.assert_called_once_with(self.session_id)
        self.mock_summarise_title.assert_awaited_once_with(question, self.mock_nvidia_rotator, max_words=5)
        self.manager.update_session_title.assert_called_once_with(self.session_id, "Leg Pain Inquiry")

    async def test_update_session_title_not_first_message(self):
        """With three stored messages the title is not regenerated."""
        mock_session = MagicMock(spec=Session, messages=[1, 2, 3])  # len != 2
        self.manager.get_session = MagicMock(return_value=mock_session)
        self.manager.update_session_title = MagicMock()

        await self.manager._update_session_title_if_first_message(
            self.session_id, "question", self.mock_nvidia_rotator
        )

        self.manager.get_session.assert_called_once_with(self.session_id)
        self.mock_summarise_title.assert_not_awaited()
        self.manager.update_session_title.assert_not_called()

    async def test_generate_summary_gemini_success(self):
        """Gemini's summary is used and NVIDIA is never called."""
        self.mock_summarise_gemini.return_value = "Gemini summary"
        result = await self.manager._generate_summary(
            "q", "a", self.mock_gemini_rotator, self.mock_nvidia_rotator
        )
        self.assertEqual(result, "Gemini summary")
        self.mock_summarise_gemini.assert_awaited_once()
        self.mock_summarise_nvidia.assert_not_awaited()

    async def test_generate_summary_gemini_fails_nvidia_success(self):
        """A None result from Gemini falls back to the NVIDIA summariser."""
        self.mock_summarise_gemini.return_value = None  # Gemini fails
        self.mock_summarise_nvidia.return_value = "NVIDIA summary"
        result = await self.manager._generate_summary(
            "q", "a", self.mock_gemini_rotator, self.mock_nvidia_rotator
        )
        self.assertEqual(result, "NVIDIA summary")
        self.mock_summarise_gemini.assert_awaited_once()
        self.mock_summarise_nvidia.assert_awaited_once()

    async def test_generate_summary_all_fail(self):
        """If both summarisers return None, a raw Question/Answer string is used."""
        self.mock_summarise_gemini.return_value = None
        self.mock_summarise_nvidia.return_value = None
        result = await self.manager._generate_summary(
            "question", "answer", self.mock_gemini_rotator, self.mock_nvidia_rotator
        )
        self.assertEqual(result, "Question: question\nAnswer: answer")

    async def test_filter_summaries_for_relevance_success(self):
        """Only the summaries echoed back by the model are kept, in order."""
        summaries = ["Summary A", "Summary B", "Summary C"]
        self.mock_nvidia_chat.return_value = "Summary A\nSummary C"
        result = await self.manager._filter_summaries_for_relevance(
            "question", summaries, self.mock_nvidia_rotator
        )
        self.assertEqual(result, ["Summary A", "Summary C"])
        self.mock_nvidia_chat.assert_awaited_once()

    async def test_filter_summaries_for_relevance_api_fails(self):
        """If the chat call raises, every summary is returned as a fallback."""
        summaries = ["Summary A", "Summary B"]
        self.mock_nvidia_chat.side_effect = Exception("API error")
        result = await self.manager._filter_summaries_for_relevance(
            "question", summaries, self.mock_nvidia_rotator
        )
        self.assertEqual(result, summaries)
if __name__ == "__main__":
    # Allow running this file directly, e.g.
    # `python -m tests.test_memory_manager`; module="__main__" is the default,
    # spelled out so it is clear which module's TestCases are collected.
    unittest.main(module="__main__")