dylanglenister commited on
Commit
969f308
·
1 Parent(s): f15fec7

Updated session tests

Browse files
Files changed (1) hide show
  1. tests/test_session.py +117 -124
tests/test_session.py CHANGED
@@ -1,8 +1,10 @@
1
  import time
2
  import unittest
3
  from datetime import datetime, timedelta, timezone
 
4
 
5
  from bson import ObjectId
 
6
 
7
  from src.data.connection import (ActionFailed, Collections, close_connection,
8
  get_collection)
@@ -12,13 +14,13 @@ from tests.base_test import BaseMongoTest
12
 
13
 
14
  class TestSessionRepository(BaseMongoTest):
 
15
 
16
  def setUp(self):
17
  """Set up a clean test environment before each test."""
18
  super().setUp()
19
  self.test_collection = self._collections[Collections.SESSION]
20
  session_repo.init(collection_name=self.test_collection, drop=True)
21
-
22
  self.account_id = str(ObjectId())
23
  self.patient_id = str(ObjectId())
24
 
@@ -30,155 +32,146 @@ class TestSessionRepository(BaseMongoTest):
30
 
31
  def test_create_and_get_session(self):
32
  """Test chat session creation and retrieval by ID."""
33
- # Test creation
34
  session = session_repo.create_session(
35
- self.account_id,
36
- self.patient_id,
37
- "Test Chat",
38
- collection_name=self.test_collection
39
  )
40
- self.assertIn("_id", session)
41
  self.assertIsInstance(session["_id"], str)
42
- self.assertEqual(session["title"], "Test Chat")
43
- self.assertEqual(len(session["messages"]), 0)
44
-
45
- # Test retrieval
46
  retrieved = session_repo.get_session(session["_id"], collection_name=self.test_collection)
47
  self.assertIsNotNone(retrieved)
48
  self.assertEqual(retrieved["_id"], session["_id"]) # type: ignore
49
- self.assertEqual(retrieved["account_id"], self.account_id) # type: ignore
50
- self.assertEqual(retrieved["patient_id"], self.patient_id) # type: ignore
51
-
52
  # Test getting a non-existent session
53
- non_existent = session_repo.get_session(str(ObjectId()), collection_name=self.test_collection)
54
- self.assertIsNone(non_existent)
55
 
56
  def test_add_and_get_messages(self):
57
  """Test adding messages and retrieving them in the correct order."""
58
- session = session_repo.create_session(
59
- self.account_id, self.patient_id, "Message Test", collection_name=self.test_collection
60
- )
61
- session_id = session["_id"]
62
-
63
- # Add messages and verify session's updated_at timestamp changes
64
- original_doc = self.get_doc_by_id(Collections.SESSION, session_id)
65
- time.sleep(0.01) # Ensure timestamp will be different
66
- session_repo.add_message(session_id, "User message 1", True, collection_name=self.test_collection)
67
- updated_doc = self.get_doc_by_id(Collections.SESSION, session_id)
68
- self.assertLess(original_doc["updated_at"], updated_doc["updated_at"]) # type: ignore
69
-
70
- session_repo.add_message(session_id, "AI response 1", False, collection_name=self.test_collection)
71
- session_repo.add_message(session_id, "User message 2", True, collection_name=self.test_collection)
72
-
73
- # Test message retrieval (should be in descending order of creation)
74
- messages = session_repo.get_session_messages(session_id, collection_name=self.test_collection)
75
- self.assertEqual(len(messages), 3)
76
- self.assertEqual(messages[0]["_id"], 2)
77
- self.assertEqual(messages[0]["content"], "User message 2")
78
- self.assertEqual(messages[1]["_id"], 1)
79
- self.assertEqual(messages[2]["_id"], 0)
80
-
81
  # Test limit
82
- limited_messages = session_repo.get_session_messages(session_id, limit=2, collection_name=self.test_collection)
83
- self.assertEqual(len(limited_messages), 2)
84
- self.assertEqual(limited_messages[0]["_id"], 2)
85
-
86
- # Test adding message to non-existent session
87
- with self.assertRaises(ActionFailed):
88
- session_repo.add_message(str(ObjectId()), "ghost", True, collection_name=self.test_collection)
89
-
90
- def test_list_patient_sessions(self):
91
- """Test listing sessions for a specific patient, sorted by update time."""
92
- p_id_1 = str(ObjectId())
93
- p_id_2 = str(ObjectId())
94
-
95
- # Create sessions, sleeping briefly to ensure distinct updated_at times
96
- session_repo.create_session(self.account_id, p_id_1, "P1 Chat 1", collection_name=self.test_collection)
97
- time.sleep(0.01)
98
- session_repo.create_session(self.account_id, p_id_2, "P2 Chat 1", collection_name=self.test_collection) # Belongs to other patient
99
- time.sleep(0.01)
100
- s2 = session_repo.create_session(self.account_id, p_id_1, "P1 Chat 2", collection_name=self.test_collection)
101
-
102
- # Test listing for patient 1
103
- sessions = session_repo.list_patient_sessions(p_id_1, collection_name=self.test_collection)
104
- self.assertEqual(len(sessions), 2)
105
- self.assertEqual(sessions[0]["_id"], s2["_id"]) # Most recently created should be first
106
-
107
- def test_get_user_sessions(self):
108
- """Test listing sessions for a specific user, sorted by update time."""
109
- user1 = str(ObjectId())
110
- user2 = str(ObjectId())
111
-
112
- s1 = session_repo.create_session(user1, self.patient_id, "U1 Chat 1", collection_name=self.test_collection)
113
- time.sleep(0.01)
114
- session_repo.create_session(user2, self.patient_id, "U2 Chat 1", collection_name=self.test_collection)
115
- time.sleep(0.01)
116
- s3 = session_repo.create_session(user1, self.patient_id, "U1 Chat 2", collection_name=self.test_collection)
117
-
118
- sessions = session_repo.get_user_sessions(user1, collection_name=self.test_collection)
119
- self.assertEqual(len(sessions), 2)
120
- self.assertEqual(sessions[0]["_id"], s3["_id"]) # s3 was updated most recently
121
- self.assertEqual(sessions[1]["_id"], s1["_id"])
122
-
123
- # Test limit
124
- sessions_limited = session_repo.get_user_sessions(user1, limit=1, collection_name=self.test_collection)
125
- self.assertEqual(len(sessions_limited), 1)
126
 
127
  def test_update_session_title(self):
128
- """Test updating a session's title."""
129
- session = session_repo.create_session(self.account_id, self.patient_id, "Old Title", collection_name=self.test_collection)
130
- session_id = session["_id"]
131
-
132
- success = session_repo.update_session_title(session_id, "New Title", collection_name=self.test_collection)
133
  self.assertTrue(success)
134
-
135
- updated_session = session_repo.get_session(session_id, collection_name=self.test_collection)
136
- self.assertEqual(updated_session["title"], "New Title") # type: ignore
137
-
138
- # Test updating non-existent session
139
- success_fail = session_repo.update_session_title(str(ObjectId()), "Ghost", collection_name=self.test_collection)
140
- self.assertFalse(success_fail)
141
 
142
  def test_delete_session(self):
143
  """Test deleting a session."""
144
  session = session_repo.create_session(self.account_id, self.patient_id, "To Delete", collection_name=self.test_collection)
145
- session_id = session["_id"]
146
-
147
- success = session_repo.delete_session(session_id, collection_name=self.test_collection)
148
- self.assertTrue(success)
149
-
150
- deleted_session = session_repo.get_session(session_id, collection_name=self.test_collection)
151
- self.assertIsNone(deleted_session)
152
-
153
- # Test deleting non-existent session
154
- success_fail = session_repo.delete_session(str(ObjectId()), collection_name=self.test_collection)
155
- self.assertFalse(success_fail)
156
 
157
  def test_prune_old_sessions(self):
158
  """Test deleting sessions older than a specified number of days."""
159
- coll = get_collection(self.test_collection)
160
- now = datetime.now(timezone.utc)
161
- old_date = now - timedelta(days=31)
162
-
163
- # Manually insert one old and one new session
164
- coll.insert_one({
165
- "account_id": ObjectId(self.account_id), "patient_id": ObjectId(self.patient_id),
166
- "title": "Old Session", "created_at": old_date, "updated_at": old_date, "messages": []
167
- })
168
- coll.insert_one({
169
- "account_id": ObjectId(self.account_id), "patient_id": ObjectId(self.patient_id),
170
- "title": "New Session", "created_at": now, "updated_at": now, "messages": []
171
- })
172
-
173
- self.assertEqual(coll.count_documents({}), 2)
174
 
 
175
  deleted_count = session_repo.prune_old_sessions(days=30, collection_name=self.test_collection)
176
  self.assertEqual(deleted_count, 1)
177
- self.assertEqual(coll.count_documents({}), 1)
178
 
179
- remaining = coll.find_one()
180
- self.assertEqual(remaining["title"], "New Session") # type: ignore
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
  if __name__ == "__main__":
184
  try:
 
1
  import time
2
  import unittest
3
  from datetime import datetime, timedelta, timezone
4
+ from unittest.mock import patch
5
 
6
  from bson import ObjectId
7
+ from pymongo.errors import ConnectionFailure, WriteError
8
 
9
  from src.data.connection import (ActionFailed, Collections, close_connection,
10
  get_collection)
 
14
 
15
 
16
  class TestSessionRepository(BaseMongoTest):
17
+ """Test class for the 'happy path' and edge cases of session repository functions."""
18
 
19
  def setUp(self):
20
  """Set up a clean test environment before each test."""
21
  super().setUp()
22
  self.test_collection = self._collections[Collections.SESSION]
23
  session_repo.init(collection_name=self.test_collection, drop=True)
 
24
  self.account_id = str(ObjectId())
25
  self.patient_id = str(ObjectId())
26
 
 
32
 
33
  def test_create_and_get_session(self):
34
  """Test chat session creation and retrieval by ID."""
 
35
  session = session_repo.create_session(
36
+ self.account_id, self.patient_id, "Test Chat", collection_name=self.test_collection
 
 
 
37
  )
 
38
  self.assertIsInstance(session["_id"], str)
 
 
 
 
39
  retrieved = session_repo.get_session(session["_id"], collection_name=self.test_collection)
40
  self.assertIsNotNone(retrieved)
41
  self.assertEqual(retrieved["_id"], session["_id"]) # type: ignore
 
 
 
42
  # Test getting a non-existent session
43
+ self.assertIsNone(session_repo.get_session(str(ObjectId()), collection_name=self.test_collection))
 
44
 
45
  def test_add_and_get_messages(self):
46
  """Test adding messages and retrieving them in the correct order."""
47
+ session = session_repo.create_session(self.account_id, self.patient_id, "Msg Test", collection_name=self.test_collection)
48
+ session_repo.add_message(session["_id"], "User message 1", True, collection_name=self.test_collection)
49
+ session_repo.add_message(session["_id"], "AI response 1", False, collection_name=self.test_collection)
50
+ messages = session_repo.get_session_messages(session["_id"], collection_name=self.test_collection)
51
+ self.assertEqual(len(messages), 2)
52
+ self.assertEqual(messages[0]["content"], "AI response 1") # Descending order
53
+ self.assertEqual(messages[1]["_id"], 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  # Test limit
55
+ self.assertEqual(len(session_repo.get_session_messages(session["_id"], limit=1, collection_name=self.test_collection)), 1)
56
+
57
+ def test_list_sessions(self):
58
+ """Test listing sessions for a specific patient and user, sorted by update time."""
59
+ session_repo.create_session(self.account_id, self.patient_id, "First", collection_name=self.test_collection)
60
+ time.sleep(0.01) # Ensure distinct timestamps
61
+ s2 = session_repo.create_session(self.account_id, self.patient_id, "Second", collection_name=self.test_collection)
62
+ # Test listing for patient
63
+ patient_sessions = session_repo.list_patient_sessions(self.patient_id, collection_name=self.test_collection)
64
+ self.assertEqual(len(patient_sessions), 2)
65
+ self.assertEqual(patient_sessions[0]["_id"], s2["_id"]) # Most recent first
66
+ # Test listing for user
67
+ user_sessions = session_repo.get_user_sessions(self.account_id, collection_name=self.test_collection)
68
+ self.assertEqual(len(user_sessions), 2)
69
+ self.assertEqual(user_sessions[0]["_id"], s2["_id"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  def test_update_session_title(self):
72
+ """Test updating a session's title and its timestamp."""
73
+ session = session_repo.create_session(self.account_id, self.patient_id, "Old", collection_name=self.test_collection)
74
+ original_doc = self.get_doc_by_id(Collections.SESSION, session["_id"])
75
+ success = session_repo.update_session_title(session["_id"], "New", collection_name=self.test_collection)
 
76
  self.assertTrue(success)
77
+ updated_doc = self.get_doc_by_id(Collections.SESSION, session["_id"])
78
+ self.assertEqual(updated_doc["title"], "New") # type: ignore
79
+ self.assertLess(original_doc["updated_at"], updated_doc["updated_at"]) # type: ignore
80
+ self.assertFalse(session_repo.update_session_title(str(ObjectId()), "Ghost", collection_name=self.test_collection))
 
 
 
81
 
82
  def test_delete_session(self):
83
  """Test deleting a session."""
84
  session = session_repo.create_session(self.account_id, self.patient_id, "To Delete", collection_name=self.test_collection)
85
+ self.assertTrue(session_repo.delete_session(session["_id"], collection_name=self.test_collection))
86
+ self.assertIsNone(session_repo.get_session(session["_id"], collection_name=self.test_collection))
87
+ self.assertFalse(session_repo.delete_session(str(ObjectId()), collection_name=self.test_collection))
 
 
 
 
 
 
 
 
88
 
89
  def test_prune_old_sessions(self):
90
  """Test deleting sessions older than a specified number of days."""
91
+ # Create two valid sessions
92
+ old_session = session_repo.create_session(self.account_id, self.patient_id, "Old Session", collection_name=self.test_collection)
93
+ session_repo.create_session(self.account_id, self.patient_id, "New Session", collection_name=self.test_collection)
94
+
95
+ # Manually update one session to be old
96
+ old_date = datetime.now(timezone.utc) - timedelta(days=31)
97
+ get_collection(self.test_collection).update_one(
98
+ {"_id": ObjectId(old_session["_id"])},
99
+ {"$set": {"updated_at": old_date}}
100
+ )
 
 
 
 
 
101
 
102
+ self.assertEqual(get_collection(self.test_collection).count_documents({}), 2)
103
  deleted_count = session_repo.prune_old_sessions(days=30, collection_name=self.test_collection)
104
  self.assertEqual(deleted_count, 1)
105
+ self.assertEqual(get_collection(self.test_collection).count_documents({}), 1)
106
 
 
 
107
 
108
+ class TestSessionRepositoryExceptions(BaseMongoTest):
109
+ """Test class for the exception handling of all session repository functions."""
110
+
111
+ def setUp(self):
112
+ """Set up a clean test environment before each test."""
113
+ super().setUp()
114
+ self.test_collection = self._collections[Collections.SESSION]
115
+ session_repo.init(collection_name=self.test_collection, drop=True)
116
+ self.account_id = str(ObjectId())
117
+ self.patient_id = str(ObjectId())
118
+
119
+ @patch('src.data.repositories.session.get_collection')
120
+ def test_write_error_raises_action_failed(self, mock_get_collection):
121
+ """Test that a WriteError during an operation is raised as ActionFailed."""
122
+ # Configure the mock to return a collection object whose methods raise errors
123
+ mock_collection = mock_get_collection.return_value
124
+ mock_collection.update_one.side_effect = WriteError("Simulated schema validation error")
125
+ mock_collection.find_one.return_value = {"messages": []} # Needed for add_message to proceed
126
+
127
+ with self.assertRaises(ActionFailed):
128
+ # This will fail inside at the update_one call, which we've mocked
129
+ session_repo.add_message("68e212480769b3f99015f43c", "content", True, collection_name=self.test_collection)
130
+
131
+ def test_invalid_id_raises_action_failed(self):
132
+ """Test that functions raise ActionFailed when given a malformed ObjectId string."""
133
+ with self.assertRaises(ActionFailed):
134
+ session_repo.create_session("bad-id", self.patient_id, "t", collection_name=self.test_collection)
135
+ with self.assertRaises(ActionFailed):
136
+ session_repo.get_user_sessions("bad-id", collection_name=self.test_collection)
137
+ with self.assertRaises(ActionFailed):
138
+ session_repo.list_patient_sessions("bad-id", collection_name=self.test_collection)
139
+ with self.assertRaises(ActionFailed):
140
+ session_repo.get_session("bad-id", collection_name=self.test_collection)
141
+ with self.assertRaises(ActionFailed):
142
+ session_repo.get_session_messages("bad-id", collection_name=self.test_collection)
143
+ with self.assertRaises(ActionFailed):
144
+ session_repo.update_session_title("bad-id", "t", collection_name=self.test_collection)
145
+ with self.assertRaises(ActionFailed):
146
+ session_repo.delete_session("bad-id", collection_name=self.test_collection)
147
+ with self.assertRaises(ActionFailed):
148
+ session_repo.add_message("bad-id", "t", True, collection_name=self.test_collection)
149
+
150
+ @patch('src.data.repositories.session.get_collection')
151
+ def test_all_functions_raise_on_connection_error(self, mock_get_collection):
152
+ """Test that all repo functions catch generic PyMongoErrors and raise ActionFailed."""
153
+ mock_get_collection.side_effect = ConnectionFailure("Simulated connection error")
154
+
155
+ with self.assertRaises(ActionFailed):
156
+ session_repo.init(collection_name=self.test_collection, drop=True)
157
+ with self.assertRaises(ActionFailed):
158
+ session_repo.create_session(self.account_id, self.patient_id, "t", collection_name=self.test_collection)
159
+ with self.assertRaises(ActionFailed):
160
+ session_repo.get_user_sessions(self.account_id, collection_name=self.test_collection)
161
+ with self.assertRaises(ActionFailed):
162
+ session_repo.list_patient_sessions(self.patient_id, collection_name=self.test_collection)
163
+ with self.assertRaises(ActionFailed):
164
+ session_repo.get_session(str(ObjectId()), collection_name=self.test_collection)
165
+ with self.assertRaises(ActionFailed):
166
+ session_repo.get_session_messages(str(ObjectId()), collection_name=self.test_collection)
167
+ with self.assertRaises(ActionFailed):
168
+ session_repo.update_session_title(str(ObjectId()), "t", collection_name=self.test_collection)
169
+ with self.assertRaises(ActionFailed):
170
+ session_repo.delete_session(str(ObjectId()), collection_name=self.test_collection)
171
+ with self.assertRaises(ActionFailed):
172
+ session_repo.prune_old_sessions(collection_name=self.test_collection)
173
+ with self.assertRaises(ActionFailed):
174
+ session_repo.add_message(str(ObjectId()), "t", True, collection_name=self.test_collection)
175
 
176
  if __name__ == "__main__":
177
  try: