File size: 6,140 Bytes
368f485
 
 
 
 
 
 
 
21c55a3
368f485
d753c16
368f485
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import unittest
from unittest.mock import patch

from bson import ObjectId
from pymongo.errors import ConnectionFailure

from src.data.connection import ActionFailed, Collections
from src.data.repositories import medical_memory as medical_memory_repo
from src.models.medical import MedicalMemory, SemanticSearchResult
from src.utils.logger import logger
from ..base_test import BaseMongoTest


class TestMedicalMemoryRepository(BaseMongoTest):
	"""Test class for the 'happy path' of medical memory repository functions."""

	def setUp(self):
		"""Set up a clean test environment before each test."""
		super().setUp()
		self.test_collection = self._collections[Collections.MEDICAL_MEMORY]
		medical_memory_repo.init(collection_name=self.test_collection, drop=True)
		self.patient_id = str(ObjectId())
		self.doctor_id = str(ObjectId())
		self.session_id = str(ObjectId())
		self.embedding = [0.1, 0.2, 0.3]

	def test_init_functionality(self):
		"""Test that the init function correctly sets up the collection."""
		self.assertIn(self.test_collection, self.db.list_collection_names())

	def test_create_memory(self):
		"""Test successful creation of a medical memory with and without optional fields."""
		# Test full creation
		memory_id = medical_memory_repo.create_memory(
			self.patient_id, self.doctor_id, "Full summary", self.session_id, self.embedding,
			collection_name=self.test_collection
		)
		self.assertIsInstance(memory_id, str)
		doc = self.get_doc_by_id(Collections.MEDICAL_MEMORY, memory_id)
		self.assertIsNotNone(doc)
		self.assertEqual(doc["summary"], "Full summary") # type: ignore
		self.assertEqual(len(doc["embedding"]), 3) # type: ignore

		# Test minimal creation
		min_id = medical_memory_repo.create_memory(
			self.patient_id, self.doctor_id, "Minimal summary", collection_name=self.test_collection
		)
		self.assertIsInstance(min_id, str)

	def test_get_recent_memories(self):
		"""Test retrieving recent memories, verifying sorting, filtering, and limit."""
		medical_memory_repo.create_memory(self.patient_id, self.doctor_id, "Oldest", collection_name=self.test_collection)
		medical_memory_repo.create_memory(str(ObjectId()), self.doctor_id, "Other Patient", collection_name=self.test_collection)
		medical_memory_repo.create_memory(self.patient_id, self.doctor_id, "Newest", collection_name=self.test_collection)

		memories = medical_memory_repo.get_recent_memories(self.patient_id, collection_name=self.test_collection)
		self.assertEqual(len(memories), 2)
		self.assertIsInstance(memories[0], MedicalMemory)
		self.assertEqual(memories[0].summary, "Newest") # Descending sort order

		# Test limit
		limited = medical_memory_repo.get_recent_memories(self.patient_id, limit=1, collection_name=self.test_collection)
		self.assertEqual(len(limited), 1)

	def test_search_memories_semantic(self):
		"""Test semantic search functionality, verifying similarity logic and sorting."""
		# Create memories with known embeddings
		vec_a = [1.0, 0.0, 0.0] # Most similar
		vec_b = [0.7, 0.7, 0.0] # Less similar
		vec_c = [0.0, 0.0, 1.0] # Not similar
		medical_memory_repo.create_memory(self.patient_id, self.doctor_id, "Vec A", embedding=vec_a, collection_name=self.test_collection)
		medical_memory_repo.create_memory(self.patient_id, self.doctor_id, "Vec B", embedding=vec_b, collection_name=self.test_collection)
		medical_memory_repo.create_memory(self.patient_id, self.doctor_id, "Vec C", embedding=vec_c, collection_name=self.test_collection)
		medical_memory_repo.create_memory(self.patient_id, self.doctor_id, "No Embedding", collection_name=self.test_collection)

		query_embedding = [0.9, 0.1, 0.0]
		results = medical_memory_repo.search_memories_semantic(self.patient_id, query_embedding, collection_name=self.test_collection)

		self.assertEqual(len(results), 3) # Vec C should be filtered by default numpy math
		self.assertIsInstance(results[0], SemanticSearchResult)
		self.assertEqual(results[0].summary, "Vec A") # Most similar should be first
		self.assertGreater(results[0].similarity_score, results[1].similarity_score)


class TestMedicalMemoryRepositoryExceptions(BaseMongoTest):
	"""Test class for the exception handling of medical memory repository functions."""

	def setUp(self):
		"""Set up the test environment before each test."""
		super().setUp()
		self.test_collection = self._collections[Collections.MEDICAL_MEMORY]
		medical_memory_repo.init(collection_name=self.test_collection, drop=True)
		self.patient_id = str(ObjectId())
		self.doctor_id = str(ObjectId())

	def test_invalid_id_raises_action_failed(self):
		"""Test that functions raise ActionFailed when given a malformed ObjectId string."""
		with self.assertRaises(ActionFailed):
			medical_memory_repo.create_memory("bad-id", self.doctor_id, "t", collection_name=self.test_collection)
		with self.assertRaises(ActionFailed):
			medical_memory_repo.get_recent_memories("bad-id", collection_name=self.test_collection)
		with self.assertRaises(ActionFailed):
			medical_memory_repo.search_memories_semantic("bad-id", [], collection_name=self.test_collection)

	@patch('src.data.repositories.medical_memory.get_collection')
	def test_all_functions_raise_on_connection_error(self, mock_get_collection):
		"""Test that all repo functions catch generic PyMongoErrors and raise ActionFailed."""
		mock_get_collection.side_effect = ConnectionFailure("Simulated connection error")

		with self.assertRaises(ActionFailed):
			medical_memory_repo.init(collection_name=self.test_collection, drop=True)
		with self.assertRaises(ActionFailed):
			medical_memory_repo.create_memory(self.patient_id, self.doctor_id, "t", collection_name=self.test_collection)
		with self.assertRaises(ActionFailed):
			medical_memory_repo.get_recent_memories(self.patient_id, collection_name=self.test_collection)
		with self.assertRaises(ActionFailed):
			medical_memory_repo.search_memories_semantic(self.patient_id, [], collection_name=self.test_collection)

if __name__ == "__main__":
	logger().info("Starting MongoDB repository integration tests...")
	unittest.main(verbosity=2)
	logger().info("Tests completed and database connection closed.")