dylanglenister committed on
Commit
8cd3eec
·
1 Parent(s): 5fe2c55

Updated utils and created test file

Browse files
Files changed (2) hide show
  1. src/data/utils.py +64 -24
  2. tests/test_utils.py +82 -0
src/data/utils.py CHANGED
@@ -1,10 +1,9 @@
1
- # data/repositories/utils.py
2
 
3
  from datetime import datetime, timedelta, timezone
4
 
5
  from pymongo import ASCENDING
6
- from pymongo.errors import (ConnectionFailure, DuplicateKeyError,
7
- OperationFailure, PyMongoError)
8
 
9
  from src.data.connection import ActionFailed, get_collection, get_database
10
  from src.utils.logger import logger
@@ -16,35 +15,76 @@ def create_index(
16
  *,
17
  unique: bool = False
18
  ) -> None:
19
- """Creates an index on a specified collection."""
20
- collection = get_collection(collection_name)
21
- collection.create_index([(field_name, ASCENDING)], unique=unique)
 
 
 
 
 
 
 
 
 
 
22
 
23
  def delete_old_data(
24
  collection_name: str,
 
25
  *,
26
  days: int = 30
27
  ) -> int:
28
- """Deletes data older than a specified number of days."""
29
- collection = get_collection(collection_name)
30
- cutoff = datetime.now(timezone.utc) - timedelta(days=days)
31
- result = collection.delete_many({
32
- "updated_at": {"$lt": cutoff}
33
- })
34
- return result.deleted_count
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  def backup_collection(collection_name: str) -> str:
37
- """Creates a timestamped backup of a collection using an aggregation pipeline."""
38
- db = get_database()
39
- backup_name = f"{collection_name}_backup_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}"
 
 
40
 
41
- if backup_name in db.list_collection_names():
42
- db.drop_collection(backup_name)
 
 
 
 
43
 
44
- source_collection = get_collection(collection_name)
45
- pipeline = [{"$match": {}}, {"$out": backup_name}]
46
- source_collection.aggregate(pipeline)
 
 
47
 
48
- doc_count = db[backup_name].count_documents({})
49
- logger().info(f"Created backup '{backup_name}' with {doc_count} documents.")
50
- return backup_name
 
 
 
 
1
+ # src/data/utils.py
2
 
3
  from datetime import datetime, timedelta, timezone
4
 
5
  from pymongo import ASCENDING
6
+ from pymongo.errors import ConnectionFailure, PyMongoError
 
7
 
8
  from src.data.connection import ActionFailed, get_collection, get_database
9
  from src.utils.logger import logger
 
15
  *,
16
  unique: bool = False
17
  ) -> None:
18
+ """
19
+ Creates an index on a specified collection.
20
+
21
+ Raises:
22
+ ActionFailed: If a database error occurs.
23
+ """
24
+ try:
25
+ collection = get_collection(collection_name)
26
+ collection.create_index([(field_name, ASCENDING)], unique=unique)
27
+ logger().info(f"Ensured index exists on '{field_name}' for collection '{collection_name}'.")
28
+ except (ConnectionFailure, PyMongoError) as e:
29
+ logger().error(f"Failed to create index on '{collection_name}': {e}")
30
+ raise ActionFailed("A database error occurred while creating an index.") from e
31
 
32
  def delete_old_data(
33
  collection_name: str,
34
+ timestamp_field: str = "updated_at",
35
  *,
36
  days: int = 30
37
  ) -> int:
38
+ """
39
+ Deletes documents from a collection older than a specified number of days.
40
+
41
+ Args:
42
+ collection_name: The name of the collection to prune.
43
+ timestamp_field: The name of the datetime field to check. Defaults to "updated_at".
44
+ days: The age in days beyond which documents will be deleted.
45
+
46
+ Returns:
47
+ The number of documents deleted.
48
+
49
+ Raises:
50
+ ActionFailed: If a database error occurs.
51
+ """
52
+ try:
53
+ collection = get_collection(collection_name)
54
+ cutoff = datetime.now(timezone.utc) - timedelta(days=days)
55
+ result = collection.delete_many({
56
+ timestamp_field: {"$lt": cutoff}
57
+ })
58
+ if result.deleted_count > 0:
59
+ logger().info(f"Deleted {result.deleted_count} old documents from '{collection_name}'.")
60
+ return result.deleted_count
61
+ except (ConnectionFailure, PyMongoError) as e:
62
+ logger().error(f"Failed to delete old data from '{collection_name}': {e}")
63
+ raise ActionFailed("A database error occurred while deleting old data.") from e
64
 
65
  def backup_collection(collection_name: str) -> str:
66
+ """
67
+ Creates a timestamped backup of a collection using an aggregation pipeline.
68
+
69
+ Returns:
70
+ The name of the newly created backup collection.
71
 
72
+ Raises:
73
+ ActionFailed: If a database error occurs.
74
+ """
75
+ try:
76
+ db = get_database()
77
+ backup_name = f"{collection_name}_backup_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}"
78
 
79
+ # This operation is idempotent, so no need to check for existence first.
80
+ # The $out stage will automatically replace the collection if it exists.
81
+ source_collection = get_collection(collection_name)
82
+ pipeline = [{"$match": {}}, {"$out": backup_name}]
83
+ source_collection.aggregate(pipeline)
84
 
85
+ doc_count = db[backup_name].count_documents({})
86
+ logger().info(f"Created backup '{backup_name}' with {doc_count} documents.")
87
+ return backup_name
88
+ except (ConnectionFailure, PyMongoError) as e:
89
+ logger().error(f"Failed to back up collection '{collection_name}': {e}")
90
+ raise ActionFailed("A database error occurred during collection backup.") from e
tests/test_utils.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from datetime import datetime, timedelta, timezone
3
+ from unittest.mock import patch
4
+
5
+ from pymongo.errors import ConnectionFailure
6
+
7
+ from src.data import utils as db_utils
8
+ from src.data.connection import ActionFailed, Collections, get_collection
9
+ from src.utils.logger import logger
10
+ from tests.base_test import BaseMongoTest
11
+
12
+
13
class TestDatabaseUtils(BaseMongoTest):
    """Test class for the 'happy path' of all database utility functions."""

    def setUp(self):
        """Resolve the collection used by every test after the base set-up runs."""
        super().setUp()
        self.test_collection_name = self._collections[Collections.ACCOUNT]
        self.test_collection = get_collection(self.test_collection_name)

    def test_create_index(self):
        """Test that an index is correctly created on a collection."""
        db_utils.create_index(self.test_collection_name, "test_field")
        index_info = self.test_collection.index_information()
        self.assertIn("test_field_1", index_info)

        # Test unique index creation
        db_utils.create_index(self.test_collection_name, "unique_field", unique=True)
        index_info_unique = self.test_collection.index_information()
        self.assertTrue(index_info_unique["unique_field_1"]["unique"])

    def test_delete_old_data(self):
        """Test that only documents older than the cutoff are deleted."""
        now = datetime.now(timezone.utc)
        old_date = now - timedelta(days=31)

        # Insert one old and one new document
        self.test_collection.insert_one({"name": "old_doc", "updated_at": old_date})
        self.test_collection.insert_one({"name": "new_doc", "updated_at": now})
        self.assertEqual(self.test_collection.count_documents({}), 2)

        deleted_count = db_utils.delete_old_data(self.test_collection_name, days=30)
        self.assertEqual(deleted_count, 1)
        self.assertEqual(self.test_collection.count_documents({}), 1)

        # Query by name instead of subscripting an optional find_one() result,
        # so a wrong deletion fails as an assertion rather than a TypeError.
        self.assertIsNotNone(self.test_collection.find_one({"name": "new_doc"}))
        self.assertIsNone(self.test_collection.find_one({"name": "old_doc"}))

    def test_backup_collection(self):
        """Test that a collection is successfully backed up."""
        self.test_collection.insert_one({"name": "doc1"})
        self.test_collection.insert_one({"name": "doc2"})

        backup_name = db_utils.backup_collection(self.test_collection_name)
        try:
            self.assertIn(backup_name, self.db.list_collection_names())

            backup_collection = self.db[backup_name]
            self.assertEqual(backup_collection.count_documents({}), 2)
            self.assertEqual(
                sorted(doc["name"] for doc in backup_collection.find()),
                ["doc1", "doc2"],
            )
        finally:
            # NOTE(review): BaseMongoTest's cleanup is not visible here; drop
            # the timestamped backup explicitly so it never outlives the test.
            self.db.drop_collection(backup_name)
60
+
61
+
62
class TestDatabaseUtilsExceptions(BaseMongoTest):
    """Checks that database failures surface as ActionFailed from every utility function."""

    @patch('src.data.utils.get_collection')
    @patch('src.data.utils.get_database')
    def test_all_functions_raise_on_connection_error(self, mock_get_database, mock_get_collection):
        """Each utility must translate a simulated ConnectionFailure into ActionFailed."""
        # Make both patched accessors fail as soon as a utility asks for them.
        for accessor in (mock_get_collection, mock_get_database):
            accessor.side_effect = ConnectionFailure("Simulated connection error")

        failing_calls = (
            lambda: db_utils.create_index("any_collection", "any_field"),
            lambda: db_utils.delete_old_data("any_collection", days=30),
            lambda: db_utils.backup_collection("any_collection"),
        )
        for call in failing_calls:
            with self.assertRaises(ActionFailed):
                call()
78
+
79
if __name__ == "__main__":
    import sys

    logger().info("Starting MongoDB repository integration tests...")
    # unittest.main() calls sys.exit() by default, which would make the final
    # log line unreachable. Run with exit=False, log, then propagate the
    # outcome as the process exit status ourselves.
    program = unittest.main(verbosity=2, exit=False)
    logger().info("Tests completed.")
    sys.exit(not program.result.wasSuccessful())