cevheri commited on
Commit
5066c0b
·
1 Parent(s): b27b4c3

feat: add in-memory database usage with mongomock-motor

Browse files
.env.sample CHANGED
@@ -7,10 +7,8 @@ SECRET_KEY="1234"
7
  API_KEY="sk-lokumai=="
8
  BASE_URL="http://0.0.0.0:7860"
9
 
10
- # mongodb, mongita_memory, mongita_disk
11
- DB_DATABASE_TYPE=mongita_memory
12
- # when type is disk
13
- DB_MONGITA_DB_PATH="./local_dev.db"
14
 
15
  DB_DATABASE_NAME=lokumai
16
  DB_MONGO_URI=mongodb://localhost:27017
 
7
  API_KEY="sk-lokumai=="
8
  BASE_URL="http://0.0.0.0:7860"
9
 
10
+ # mongodb, embedded
11
+ DB_DATABASE_TYPE=embedded
 
 
12
 
13
  DB_DATABASE_NAME=lokumai
14
  DB_MONGO_URI=mongodb://localhost:27017
app/config/db.py CHANGED
@@ -11,9 +11,8 @@ class DBConfig(BaseSettings):
11
  )
12
 
13
 
14
- DATABASE_TYPE: Literal["mongodb", "mongita_disk", "mongita_memory"] = "mongodb"
15
  DATABASE_NAME: str = "openai_chatbot_api"
16
- MONGITA_DB_PATH: Optional[str] = None
17
 
18
  MONGO_USER: str = "root"
19
  MONGO_PASSWORD: str = "rootPass"
 
11
  )
12
 
13
 
14
+ DATABASE_TYPE: str = "mongodb"
15
  DATABASE_NAME: str = "openai_chatbot_api"
 
16
 
17
  MONGO_USER: str = "root"
18
  MONGO_PASSWORD: str = "rootPass"
app/core/db_client.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Optional, Dict, Any
3
+ from motor.motor_asyncio import AsyncIOMotorClient
4
+ from unittest.mock import AsyncMock
5
+
6
+ class DatabaseClient(ABC):
7
+ """Abstract base class for database clients"""
8
+
9
+ @abstractmethod
10
+ async def connect(self) -> None:
11
+ """Connect to the database"""
12
+ pass
13
+
14
+ @abstractmethod
15
+ async def disconnect(self) -> None:
16
+ """Disconnect from the database"""
17
+ pass
18
+
19
+ @abstractmethod
20
+ async def get_database(self) -> Any:
21
+ """Get database instance"""
22
+ pass
23
+
24
+ class MongoClient(DatabaseClient):
25
+ """Real MongoDB client implementation"""
26
+
27
+ def __init__(self, connection_string: str):
28
+ self.connection_string = connection_string
29
+ self._client: Optional[AsyncIOMotorClient] = None
30
+
31
+ async def connect(self) -> None:
32
+ self._client = AsyncIOMotorClient(self.connection_string)
33
+
34
+ async def disconnect(self) -> None:
35
+ if self._client:
36
+ self._client.close()
37
+
38
+ async def get_database(self) -> AsyncIOMotorClient:
39
+ if not self._client:
40
+ await self.connect()
41
+ return self._client
42
+
43
+ class MockMongoClient(DatabaseClient):
44
+ """Mock MongoDB client for testing"""
45
+
46
+ def __init__(self):
47
+ self._client = AsyncMock()
48
+
49
+ async def connect(self) -> None:
50
+ pass
51
+
52
+ async def disconnect(self) -> None:
53
+ pass
54
+
55
+ async def get_database(self) -> AsyncMock:
56
+ return self._client
57
+
58
+ class DatabaseClientFactory:
59
+ """Factory for creating database clients"""
60
+
61
+ _instance: Optional['DatabaseClientFactory'] = None
62
+ _client: Optional[DatabaseClient] = None
63
+
64
+ def __new__(cls):
65
+ if cls._instance is None:
66
+ cls._instance = super().__new__(cls)
67
+ return cls._instance
68
+
69
+ @classmethod
70
+ def create_client(cls, db_type: str, connection_string: Optional[str] = None) -> DatabaseClient:
71
+ """
72
+ Create a database client based on the database type
73
+
74
+ Args:
75
+ db_type: Type of database ('mongodb' or 'mock')
76
+ connection_string: Connection string for the database
77
+
78
+ Returns:
79
+ DatabaseClient: Instance of the appropriate database client
80
+ """
81
+ if cls._client is None:
82
+ if db_type.lower() == 'mongodb':
83
+ if not connection_string:
84
+ raise ValueError("Connection string is required for MongoDB")
85
+ cls._client = MongoClient(connection_string)
86
+ elif db_type.lower() == 'mock':
87
+ cls._client = MockMongoClient()
88
+ else:
89
+ raise ValueError(f"Unsupported database type: {db_type}")
90
+
91
+ return cls._client
92
+
93
+ @classmethod
94
+ async def get_client(cls) -> DatabaseClient:
95
+ """
96
+ Get the current database client instance
97
+
98
+ Returns:
99
+ DatabaseClient: Current database client instance
100
+ """
101
+ if cls._client is None:
102
+ raise RuntimeError("Database client not initialized")
103
+ return cls._client
104
+
105
+ @classmethod
106
+ async def reset_client(cls) -> None:
107
+ """Reset the current database client instance"""
108
+ if cls._client:
109
+ await cls._client.disconnect()
110
+ cls._client = None
app/db/client.py CHANGED
@@ -1,6 +1,8 @@
1
  # mongodb client with motor and pymongo
2
 
 
3
  from motor.motor_asyncio import AsyncIOMotorClient
 
4
  from app.config.db import db_config
5
  from loguru import logger
6
  from typing import Optional
@@ -12,37 +14,64 @@ env.read_env()
12
  DB_DATABASE_TYPE = env.str("DB_DATABASE_TYPE", "mongodb")
13
 
14
 
15
- class MongoDBClient:
16
- _instance: Optional["MongoDBClient"] = None
17
- _client: Optional[AsyncIOMotorClient] = None
18
- _db = None
19
- _is_connected: bool = False
20
 
21
- def __new__(cls):
22
- if cls._instance is None:
23
- cls._instance = super().__new__(cls)
24
- return cls._instance
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def __init__(self):
27
- if self._client is None:
28
- self._client = AsyncIOMotorClient(db_config.get_mongo_uri())
29
- self._db = self._client[db_config.DATABASE_NAME]
30
 
31
  @property
32
  def client(self) -> AsyncIOMotorClient:
 
 
 
 
 
 
33
  return self._client
34
 
35
  @property
36
  def db(self):
 
 
 
 
 
37
  return self._db
38
 
39
- async def connect(self):
40
  try:
41
  if not self._is_connected:
42
- if self._client is None:
43
- self._client = AsyncIOMotorClient(db_config.get_mongo_uri())
44
- self._db = self._client[db_config.DATABASE_NAME]
45
- await self._client.server_info()
46
  self._is_connected = True
47
  logger.info("Connected to MongoDB")
48
  except Exception as e:
@@ -50,9 +79,10 @@ class MongoDBClient:
50
  logger.error(f"Failed to connect to MongoDB: {e}")
51
  raise
52
 
53
- async def close(self):
54
  try:
55
  if self._is_connected and self._client is not None:
 
56
  self._client.close()
57
  self._client = None
58
  self._db = None
@@ -65,5 +95,90 @@ class MongoDBClient:
65
  self._is_connected = False
66
 
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  # Global instance
69
- mongodb = MongoDBClient()
 
1
  # mongodb client with motor and pymongo
2
 
3
+ from abc import ABC, abstractmethod
4
  from motor.motor_asyncio import AsyncIOMotorClient
5
+ from mongomock_motor import AsyncMongoMockClient
6
  from app.config.db import db_config
7
  from loguru import logger
8
  from typing import Optional
 
14
  DB_DATABASE_TYPE = env.str("DB_DATABASE_TYPE", "mongodb")
15
 
16
 
17
+ class DatabaseClient(ABC):
18
+ """Abstract base class for database clients"""
 
 
 
19
 
20
+ @abstractmethod
21
+ async def connect(self) -> None:
22
+ """Connect to the database"""
23
+ pass
24
+
25
+ @abstractmethod
26
+ async def close(self) -> None:
27
+ """Close the database connection"""
28
+ pass
29
+
30
+ @property
31
+ @abstractmethod
32
+ def client(self):
33
+ """Get the database client"""
34
+ pass
35
+
36
+ @property
37
+ @abstractmethod
38
+ def db(self):
39
+ """Get the database instance"""
40
+ pass
41
+
42
+
43
+ class PersistentMongoClient(DatabaseClient):
44
+ """Real MongoDB client implementation"""
45
 
46
  def __init__(self):
47
+ self._client: Optional[AsyncIOMotorClient] = None
48
+ self._db = None
49
+ self._is_connected: bool = False
50
 
51
  @property
52
  def client(self) -> AsyncIOMotorClient:
53
+ logger.info("Getting PersistentMongoClient")
54
+ if not self._client:
55
+ logger.info("Generating PersistentMongoClient")
56
+ self._client = AsyncIOMotorClient(db_config.get_mongo_uri())
57
+ self._db = self._client[db_config.DATABASE_NAME]
58
+ logger.info(f"Returning PersistentMongoClient. Host: {self._client.host}")
59
  return self._client
60
 
61
  @property
62
  def db(self):
63
+ logger.info("Getting PersistentMongoClient.db")
64
+ if not self._db:
65
+ logger.info("Generating PersistentMongoClient.db")
66
+ self._db = self.client[db_config.DATABASE_NAME]
67
+ logger.info(f"Returning PersistentMongoClient.db. Host: {self._db.host}")
68
  return self._db
69
 
70
+ async def connect(self) -> None:
71
  try:
72
  if not self._is_connected:
73
+ logger.info("Connecting to MongoDB")
74
+ await self.client.server_info()
 
 
75
  self._is_connected = True
76
  logger.info("Connected to MongoDB")
77
  except Exception as e:
 
79
  logger.error(f"Failed to connect to MongoDB: {e}")
80
  raise
81
 
82
+ async def close(self) -> None:
83
  try:
84
  if self._is_connected and self._client is not None:
85
+ logger.info("Closing MongoDB connection")
86
  self._client.close()
87
  self._client = None
88
  self._db = None
 
95
  self._is_connected = False
96
 
97
 
98
+ class EmbeddedMongoClient(DatabaseClient):
99
+ """Mock MongoDB client implementation for testing"""
100
+
101
+ def __init__(self):
102
+ logger.info("Initializing EmbeddedMongoClient")
103
+ self._client: Optional[AsyncMongoMockClient] = None
104
+ self._db = None
105
+ self._is_connected: bool = False
106
+ logger.info("EmbeddedMongoClient initialized")
107
+
108
+ @property
109
+ def client(self) -> AsyncMongoMockClient:
110
+ logger.info("Getting EmbeddedMongoClient")
111
+ if not self._client:
112
+ logger.info("Generating EmbeddedMongoClient")
113
+ self._client = AsyncMongoMockClient()
114
+ self._db = self._client[db_config.DATABASE_NAME]
115
+ logger.info(f"Returning EmbeddedMongoClient. Host: {self._client.host}")
116
+ return self._client
117
+
118
+ @property
119
+ def db(self):
120
+ logger.info("Getting EmbeddedMongoClient.db")
121
+ if not self._db:
122
+ logger.info("Generating EmbeddedMongoClient.db")
123
+ self._db = self.client[db_config.DATABASE_NAME]
124
+ logger.info(f"Returning EmbeddedMongoClient.db. Host: {self._db.host}")
125
+ return self._db
126
+
127
+ async def connect(self) -> None:
128
+ try:
129
+ if not self._is_connected:
130
+ logger.info("Connecting to EmbeddedMongoClient")
131
+ self._is_connected = True
132
+ logger.info("Connected to EmbeddedMongoClient")
133
+ except Exception as e:
134
+ self._is_connected = False
135
+ logger.error(f"Failed to connect to EmbeddedMongoClient: {e}")
136
+ raise
137
+
138
+ async def close(self) -> None:
139
+ try:
140
+ if self._is_connected and self._client is not None:
141
+ logger.info("Closing EmbeddedMongoClient connection")
142
+ self._client = None
143
+ self._db = None
144
+ self._is_connected = False
145
+ logger.info("Disconnected from EmbeddedMongoClient")
146
+ except Exception as e:
147
+ logger.warning(f"Error while closing EmbeddedMongoClient connection: {e}")
148
+ self._client = None
149
+ self._db = None
150
+ self._is_connected = False
151
+
152
+
153
+ class DatabaseClientFactory:
154
+ """Factory class for creating database clients"""
155
+
156
+ _instance: Optional["DatabaseClientFactory"] = None
157
+ _client: Optional[DatabaseClient] = None
158
+
159
+ def __new__(cls):
160
+ logger.info("Creating DatabaseClientFactory")
161
+ if cls._instance is None:
162
+ cls._instance = super().__new__(cls)
163
+ logger.info("DatabaseClientFactory created")
164
+ logger.info(f"Returning DatabaseClientFactory. Host: {cls._instance.client.host}")
165
+ return cls._instance
166
+
167
+ @classmethod
168
+ def get_client(cls) -> DatabaseClient:
169
+ """Get the appropriate database client based on configuration"""
170
+ logger.info(f"Getting DatabaseClientFactory.client with DB_DATABASE_TYPE: {DB_DATABASE_TYPE}")
171
+ logger.info(f"mongodb uri: {db_config.get_mongo_uri()}")
172
+ if cls._client is None:
173
+ if DB_DATABASE_TYPE == "mongodb":
174
+ logger.info("Creating PersistentMongoClient")
175
+ cls._client = PersistentMongoClient()
176
+ else:
177
+ logger.info("Creating EmbeddedMongoClient")
178
+ cls._client = EmbeddedMongoClient()
179
+ logger.info(f"Returning DatabaseClientFactory.client. Host: {cls._client.client.host}")
180
+ return cls._client
181
+
182
+
183
  # Global instance
184
+ db_client = DatabaseClientFactory.get_client()
app/repository/chat_repository.py CHANGED
@@ -1,11 +1,11 @@
1
  from typing import List
2
- from app.db.client import MongoDBClient
3
  from app.model.chat_model import ChatMessage, ChatCompletion
4
 
5
 
6
  class ChatRepository:
7
  def __init__(self):
8
- self.db = MongoDBClient().db
9
  self.collection = "chat_completion"
10
 
11
  def save(self, entity: ChatCompletion) -> ChatCompletion:
 
1
  from typing import List
2
+ from app.db.client import db_client
3
  from app.model.chat_model import ChatMessage, ChatCompletion
4
 
5
 
6
  class ChatRepository:
7
  def __init__(self):
8
+ self.db = db_client.db
9
  self.collection = "chat_completion"
10
 
11
  def save(self, entity: ChatCompletion) -> ChatCompletion:
main.py CHANGED
@@ -8,7 +8,7 @@ from app.config.log import log_config
8
  from loguru import logger
9
  from environs import Env
10
  from contextlib import asynccontextmanager
11
- from app.db.client import mongodb
12
  from gradio_chatbot import build_gradio_app, app_auth
13
  import gradio as gr
14
  import os
@@ -33,14 +33,12 @@ async def lifespan(app: FastAPI):
33
  """
34
  # Startup
35
  logger.info("Starting up application...")
36
- if DB_DATABASE_TYPE == "mongodb":
37
- await mongodb.connect()
38
  yield
39
 
40
  # Shutdown
41
  logger.info("Shutting down application...")
42
- if DB_DATABASE_TYPE == "mongodb":
43
- await mongodb.close()
44
 
45
 
46
  VERSION = "0.3.0"
 
8
  from loguru import logger
9
  from environs import Env
10
  from contextlib import asynccontextmanager
11
+ from app.db.client import db_client
12
  from gradio_chatbot import build_gradio_app, app_auth
13
  import gradio as gr
14
  import os
 
33
  """
34
  # Startup
35
  logger.info("Starting up application...")
36
+ await db_client.connect()
 
37
  yield
38
 
39
  # Shutdown
40
  logger.info("Shutting down application...")
41
+ await db_client.close()
 
42
 
43
 
44
  VERSION = "0.3.0"
pyproject.toml CHANGED
@@ -11,6 +11,7 @@ dependencies = [
11
  "httpx>=0.28.1",
12
  "loguru>=0.7.3",
13
  "mongita>=1.2.0",
 
14
  "motor>=3.7.1",
15
  "plotly>=6.1.1",
16
  "pydantic>=2.11.4",
 
11
  "httpx>=0.28.1",
12
  "loguru>=0.7.3",
13
  "mongita>=1.2.0",
14
+ "mongomock-motor>=0.0.36",
15
  "motor>=3.7.1",
16
  "plotly>=6.1.1",
17
  "pydantic>=2.11.4",
uv.lock CHANGED
@@ -439,6 +439,33 @@ dependencies = [
439
  ]
440
  sdist = { url = "https://files.pythonhosted.org/packages/a4/b6/9159dca5e74e497b19ab8e68c369f789dc2167d2c8cccb5fbd21377ba040/mongita-1.2.0.tar.gz", hash = "sha256:27487d1deb40d83e4ecb107c3d9d732d369a83cc5f10599ba86fc024e2ce9cda", size = 54582 }
441
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
442
  [[package]]
443
  name = "motor"
444
  version = "3.7.1"
@@ -509,6 +536,7 @@ dependencies = [
509
  { name = "httpx" },
510
  { name = "loguru" },
511
  { name = "mongita" },
 
512
  { name = "motor" },
513
  { name = "plotly" },
514
  { name = "pydantic" },
@@ -526,6 +554,7 @@ requires-dist = [
526
  { name = "httpx", specifier = ">=0.28.1" },
527
  { name = "loguru", specifier = ">=0.7.3" },
528
  { name = "mongita", specifier = ">=1.2.0" },
 
529
  { name = "motor", specifier = ">=3.7.1" },
530
  { name = "plotly", specifier = ">=6.1.1" },
531
  { name = "pydantic", specifier = ">=2.11.4" },
@@ -936,6 +965,12 @@ wheels = [
936
  { url = "https://files.pythonhosted.org/packages/6a/23/8146aad7d88f4fcb3a6218f41a60f6c2d4e3a72de72da1825dc7c8f7877c/semantic_version-2.10.0-py2.py3-none-any.whl", hash = "sha256:de78a3b8e0feda74cabc54aab2da702113e33ac9d9eb9d2389bcf1f58b7d9177", size = 15552 },
937
  ]
938
 
 
 
 
 
 
 
939
  [[package]]
940
  name = "shellingham"
941
  version = "1.5.4"
 
439
  ]
440
  sdist = { url = "https://files.pythonhosted.org/packages/a4/b6/9159dca5e74e497b19ab8e68c369f789dc2167d2c8cccb5fbd21377ba040/mongita-1.2.0.tar.gz", hash = "sha256:27487d1deb40d83e4ecb107c3d9d732d369a83cc5f10599ba86fc024e2ce9cda", size = 54582 }
441
 
442
+ [[package]]
443
+ name = "mongomock"
444
+ version = "4.3.0"
445
+ source = { registry = "https://pypi.org/simple" }
446
+ dependencies = [
447
+ { name = "packaging" },
448
+ { name = "pytz" },
449
+ { name = "sentinels" },
450
+ ]
451
+ sdist = { url = "https://files.pythonhosted.org/packages/4d/a4/4a560a9f2a0bec43d5f63104f55bc48666d619ca74825c8ae156b08547cf/mongomock-4.3.0.tar.gz", hash = "sha256:32667b79066fabc12d4f17f16a8fd7361b5f4435208b3ba32c226e52212a8c30", size = 135862 }
452
+ wheels = [
453
+ { url = "https://files.pythonhosted.org/packages/94/4d/8bea712978e3aff017a2ab50f262c620e9239cc36f348aae45e48d6a4786/mongomock-4.3.0-py2.py3-none-any.whl", hash = "sha256:5ef86bd12fc8806c6e7af32f21266c61b6c4ba96096f85129852d1c4fec1327e", size = 64891 },
454
+ ]
455
+
456
+ [[package]]
457
+ name = "mongomock-motor"
458
+ version = "0.0.36"
459
+ source = { registry = "https://pypi.org/simple" }
460
+ dependencies = [
461
+ { name = "mongomock" },
462
+ { name = "motor" },
463
+ ]
464
+ sdist = { url = "https://files.pythonhosted.org/packages/18/9f/38e42a34ebad323addaf6296d6b5d83eaf2c423adf206b757c68315e196a/mongomock_motor-0.0.36.tar.gz", hash = "sha256:3cf62352ece5af2f02e04d2f252393f88b5fe0487997da00584020cee4b8efba", size = 5754 }
465
+ wheels = [
466
+ { url = "https://files.pythonhosted.org/packages/d6/99/f5fdbbdc96bfd03e5f9c36339547a9076f5dbb5882900b7621526d41a38d/mongomock_motor-0.0.36-py3-none-any.whl", hash = "sha256:3ecb7949662b8986ff9c267fa0b1402b5b75a6afd57f03850cd6e13a067e3691", size = 7334 },
467
+ ]
468
+
469
  [[package]]
470
  name = "motor"
471
  version = "3.7.1"
 
536
  { name = "httpx" },
537
  { name = "loguru" },
538
  { name = "mongita" },
539
+ { name = "mongomock-motor" },
540
  { name = "motor" },
541
  { name = "plotly" },
542
  { name = "pydantic" },
 
554
  { name = "httpx", specifier = ">=0.28.1" },
555
  { name = "loguru", specifier = ">=0.7.3" },
556
  { name = "mongita", specifier = ">=1.2.0" },
557
+ { name = "mongomock-motor", specifier = ">=0.0.36" },
558
  { name = "motor", specifier = ">=3.7.1" },
559
  { name = "plotly", specifier = ">=6.1.1" },
560
  { name = "pydantic", specifier = ">=2.11.4" },
 
965
  { url = "https://files.pythonhosted.org/packages/6a/23/8146aad7d88f4fcb3a6218f41a60f6c2d4e3a72de72da1825dc7c8f7877c/semantic_version-2.10.0-py2.py3-none-any.whl", hash = "sha256:de78a3b8e0feda74cabc54aab2da702113e33ac9d9eb9d2389bcf1f58b7d9177", size = 15552 },
966
  ]
967
 
968
+ [[package]]
969
+ name = "sentinels"
970
+ version = "1.0.0"
971
+ source = { registry = "https://pypi.org/simple" }
972
+ sdist = { url = "https://files.pythonhosted.org/packages/ac/b7/1af07a98390aba07da31807f3723e7bbd003d6441b4b3d67b20d97702b23/sentinels-1.0.0.tar.gz", hash = "sha256:7be0704d7fe1925e397e92d18669ace2f619c92b5d4eb21a89f31e026f9ff4b1", size = 4074 }
973
+
974
  [[package]]
975
  name = "shellingham"
976
  version = "1.5.4"