ArnoChen
commited on
Commit
·
e1a2519
1
Parent(s):
65fd96e
improve conditional checks for db instance
Browse files- lightrag/kg/mongo_impl.py +31 -11
- lightrag/kg/oracle_impl.py +12 -9
- lightrag/kg/postgres_impl.py +16 -11
- lightrag/kg/tidb_impl.py +11 -11
lightrag/kg/mongo_impl.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import os
|
2 |
-
from dataclasses import dataclass
|
3 |
import numpy as np
|
4 |
import configparser
|
5 |
from tqdm.asyncio import tqdm as tqdm_async
|
@@ -27,7 +27,11 @@ if not pm.is_installed("motor"):
|
|
27 |
pm.install("motor")
|
28 |
|
29 |
try:
|
30 |
-
from motor.motor_asyncio import
|
|
|
|
|
|
|
|
|
31 |
from pymongo.operations import SearchIndexModel
|
32 |
from pymongo.errors import PyMongoError
|
33 |
except ImportError as e:
|
@@ -79,19 +83,23 @@ class ClientManager:
|
|
79 |
@final
|
80 |
@dataclass
|
81 |
class MongoKVStorage(BaseKVStorage):
|
|
|
|
|
|
|
82 |
def __post_init__(self):
|
83 |
self._collection_name = self.namespace
|
84 |
|
85 |
async def initialize(self):
|
86 |
-
if
|
87 |
self.db = await ClientManager.get_client()
|
88 |
self._data = await get_or_create_collection(self.db, self._collection_name)
|
89 |
logger.debug(f"Use MongoDB as KV {self._collection_name}")
|
90 |
|
91 |
async def finalize(self):
|
92 |
-
if
|
93 |
await ClientManager.release_client(self.db)
|
94 |
self.db = None
|
|
|
95 |
|
96 |
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
97 |
return await self._data.find_one({"_id": id})
|
@@ -148,19 +156,23 @@ class MongoKVStorage(BaseKVStorage):
|
|
148 |
@final
|
149 |
@dataclass
|
150 |
class MongoDocStatusStorage(DocStatusStorage):
|
|
|
|
|
|
|
151 |
def __post_init__(self):
|
152 |
self._collection_name = self.namespace
|
153 |
|
154 |
async def initialize(self):
|
155 |
-
if
|
156 |
self.db = await ClientManager.get_client()
|
157 |
self._data = await get_or_create_collection(self.db, self._collection_name)
|
158 |
logger.debug(f"Use MongoDB as DocStatus {self._collection_name}")
|
159 |
|
160 |
async def finalize(self):
|
161 |
-
if
|
162 |
await ClientManager.release_client(self.db)
|
163 |
self.db = None
|
|
|
164 |
|
165 |
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
166 |
return await self._data.find_one({"_id": id})
|
@@ -221,9 +233,12 @@ class MongoDocStatusStorage(DocStatusStorage):
|
|
221 |
@dataclass
|
222 |
class MongoGraphStorage(BaseGraphStorage):
|
223 |
"""
|
224 |
-
A concrete implementation using MongoDB
|
225 |
"""
|
226 |
|
|
|
|
|
|
|
227 |
def __init__(self, namespace, global_config, embedding_func):
|
228 |
super().__init__(
|
229 |
namespace=namespace,
|
@@ -233,7 +248,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|
233 |
self._collection_name = self.namespace
|
234 |
|
235 |
async def initialize(self):
|
236 |
-
if
|
237 |
self.db = await ClientManager.get_client()
|
238 |
self.collection = await get_or_create_collection(
|
239 |
self.db, self._collection_name
|
@@ -241,9 +256,10 @@ class MongoGraphStorage(BaseGraphStorage):
|
|
241 |
logger.debug(f"Use MongoDB as KG {self._collection_name}")
|
242 |
|
243 |
async def finalize(self):
|
244 |
-
if
|
245 |
await ClientManager.release_client(self.db)
|
246 |
self.db = None
|
|
|
247 |
|
248 |
#
|
249 |
# -------------------------------------------------------------------------
|
@@ -782,6 +798,9 @@ class MongoGraphStorage(BaseGraphStorage):
|
|
782 |
@final
|
783 |
@dataclass
|
784 |
class MongoVectorDBStorage(BaseVectorStorage):
|
|
|
|
|
|
|
785 |
def __post_init__(self):
|
786 |
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
787 |
cosine_threshold = kwargs.get("cosine_better_than_threshold")
|
@@ -794,7 +813,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
|
794 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
795 |
|
796 |
async def initialize(self):
|
797 |
-
if
|
798 |
self.db = await ClientManager.get_client()
|
799 |
self._data = await get_or_create_collection(self.db, self._collection_name)
|
800 |
|
@@ -804,9 +823,10 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
|
804 |
logger.debug(f"Use MongoDB as VDB {self._collection_name}")
|
805 |
|
806 |
async def finalize(self):
|
807 |
-
if
|
808 |
await ClientManager.release_client(self.db)
|
809 |
self.db = None
|
|
|
810 |
|
811 |
async def create_vector_index_if_not_exists(self):
|
812 |
"""Creates an Atlas Vector Search index."""
|
|
|
1 |
import os
|
2 |
+
from dataclasses import dataclass, field
|
3 |
import numpy as np
|
4 |
import configparser
|
5 |
from tqdm.asyncio import tqdm as tqdm_async
|
|
|
27 |
pm.install("motor")
|
28 |
|
29 |
try:
|
30 |
+
from motor.motor_asyncio import (
|
31 |
+
AsyncIOMotorClient,
|
32 |
+
AsyncIOMotorDatabase,
|
33 |
+
AsyncIOMotorCollection,
|
34 |
+
)
|
35 |
from pymongo.operations import SearchIndexModel
|
36 |
from pymongo.errors import PyMongoError
|
37 |
except ImportError as e:
|
|
|
83 |
@final
|
84 |
@dataclass
|
85 |
class MongoKVStorage(BaseKVStorage):
|
86 |
+
db: AsyncIOMotorDatabase = field(init=False)
|
87 |
+
_data: AsyncIOMotorCollection = field(init=False)
|
88 |
+
|
89 |
def __post_init__(self):
|
90 |
self._collection_name = self.namespace
|
91 |
|
92 |
async def initialize(self):
|
93 |
+
if self.db is None:
|
94 |
self.db = await ClientManager.get_client()
|
95 |
self._data = await get_or_create_collection(self.db, self._collection_name)
|
96 |
logger.debug(f"Use MongoDB as KV {self._collection_name}")
|
97 |
|
98 |
async def finalize(self):
|
99 |
+
if self.db is not None:
|
100 |
await ClientManager.release_client(self.db)
|
101 |
self.db = None
|
102 |
+
self._data = None
|
103 |
|
104 |
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
105 |
return await self._data.find_one({"_id": id})
|
|
|
156 |
@final
|
157 |
@dataclass
|
158 |
class MongoDocStatusStorage(DocStatusStorage):
|
159 |
+
db: AsyncIOMotorDatabase = field(init=False)
|
160 |
+
_data: AsyncIOMotorCollection = field(init=False)
|
161 |
+
|
162 |
def __post_init__(self):
|
163 |
self._collection_name = self.namespace
|
164 |
|
165 |
async def initialize(self):
|
166 |
+
if self.db is None:
|
167 |
self.db = await ClientManager.get_client()
|
168 |
self._data = await get_or_create_collection(self.db, self._collection_name)
|
169 |
logger.debug(f"Use MongoDB as DocStatus {self._collection_name}")
|
170 |
|
171 |
async def finalize(self):
|
172 |
+
if self.db is not None:
|
173 |
await ClientManager.release_client(self.db)
|
174 |
self.db = None
|
175 |
+
self._data = None
|
176 |
|
177 |
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
178 |
return await self._data.find_one({"_id": id})
|
|
|
233 |
@dataclass
|
234 |
class MongoGraphStorage(BaseGraphStorage):
|
235 |
"""
|
236 |
+
A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries.
|
237 |
"""
|
238 |
|
239 |
+
db: AsyncIOMotorDatabase = field(init=False)
|
240 |
+
collection: AsyncIOMotorCollection = field(init=False)
|
241 |
+
|
242 |
def __init__(self, namespace, global_config, embedding_func):
|
243 |
super().__init__(
|
244 |
namespace=namespace,
|
|
|
248 |
self._collection_name = self.namespace
|
249 |
|
250 |
async def initialize(self):
|
251 |
+
if self.db is None:
|
252 |
self.db = await ClientManager.get_client()
|
253 |
self.collection = await get_or_create_collection(
|
254 |
self.db, self._collection_name
|
|
|
256 |
logger.debug(f"Use MongoDB as KG {self._collection_name}")
|
257 |
|
258 |
async def finalize(self):
|
259 |
+
if self.db is not None:
|
260 |
await ClientManager.release_client(self.db)
|
261 |
self.db = None
|
262 |
+
self.collection = None
|
263 |
|
264 |
#
|
265 |
# -------------------------------------------------------------------------
|
|
|
798 |
@final
|
799 |
@dataclass
|
800 |
class MongoVectorDBStorage(BaseVectorStorage):
|
801 |
+
db: AsyncIOMotorDatabase = field(init=False)
|
802 |
+
_data: AsyncIOMotorCollection = field(init=False)
|
803 |
+
|
804 |
def __post_init__(self):
|
805 |
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
806 |
cosine_threshold = kwargs.get("cosine_better_than_threshold")
|
|
|
813 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
814 |
|
815 |
async def initialize(self):
|
816 |
+
if self.db is None:
|
817 |
self.db = await ClientManager.get_client()
|
818 |
self._data = await get_or_create_collection(self.db, self._collection_name)
|
819 |
|
|
|
823 |
logger.debug(f"Use MongoDB as VDB {self._collection_name}")
|
824 |
|
825 |
async def finalize(self):
|
826 |
+
if self.db is not None:
|
827 |
await ClientManager.release_client(self.db)
|
828 |
self.db = None
|
829 |
+
self._data = None
|
830 |
|
831 |
async def create_vector_index_if_not_exists(self):
|
832 |
"""Creates an Atlas Vector Search index."""
|
lightrag/kg/oracle_impl.py
CHANGED
@@ -3,7 +3,7 @@ import asyncio
|
|
3 |
|
4 |
# import html
|
5 |
import os
|
6 |
-
from dataclasses import dataclass
|
7 |
from typing import Any, Union, final
|
8 |
import numpy as np
|
9 |
import configparser
|
@@ -242,8 +242,7 @@ class ClientManager:
|
|
242 |
@final
|
243 |
@dataclass
|
244 |
class OracleKVStorage(BaseKVStorage):
|
245 |
-
|
246 |
-
# db: OracleDB
|
247 |
meta_fields = None
|
248 |
|
249 |
def __post_init__(self):
|
@@ -251,11 +250,11 @@ class OracleKVStorage(BaseKVStorage):
|
|
251 |
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
|
252 |
|
253 |
async def initialize(self):
|
254 |
-
if
|
255 |
self.db = await ClientManager.get_client()
|
256 |
|
257 |
async def finalize(self):
|
258 |
-
if
|
259 |
await ClientManager.release_client(self.db)
|
260 |
self.db = None
|
261 |
|
@@ -395,6 +394,8 @@ class OracleKVStorage(BaseKVStorage):
|
|
395 |
@final
|
396 |
@dataclass
|
397 |
class OracleVectorDBStorage(BaseVectorStorage):
|
|
|
|
|
398 |
def __post_init__(self):
|
399 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
400 |
cosine_threshold = config.get("cosine_better_than_threshold")
|
@@ -405,11 +406,11 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
|
405 |
self.cosine_better_than_threshold = cosine_threshold
|
406 |
|
407 |
async def initialize(self):
|
408 |
-
if
|
409 |
self.db = await ClientManager.get_client()
|
410 |
|
411 |
async def finalize(self):
|
412 |
-
if
|
413 |
await ClientManager.release_client(self.db)
|
414 |
self.db = None
|
415 |
|
@@ -449,15 +450,17 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
|
449 |
@final
|
450 |
@dataclass
|
451 |
class OracleGraphStorage(BaseGraphStorage):
|
|
|
|
|
452 |
def __post_init__(self):
|
453 |
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
|
454 |
|
455 |
async def initialize(self):
|
456 |
-
if
|
457 |
self.db = await ClientManager.get_client()
|
458 |
|
459 |
async def finalize(self):
|
460 |
-
if
|
461 |
await ClientManager.release_client(self.db)
|
462 |
self.db = None
|
463 |
|
|
|
3 |
|
4 |
# import html
|
5 |
import os
|
6 |
+
from dataclasses import dataclass, field
|
7 |
from typing import Any, Union, final
|
8 |
import numpy as np
|
9 |
import configparser
|
|
|
242 |
@final
|
243 |
@dataclass
|
244 |
class OracleKVStorage(BaseKVStorage):
|
245 |
+
db: OracleDB = field(init=False)
|
|
|
246 |
meta_fields = None
|
247 |
|
248 |
def __post_init__(self):
|
|
|
250 |
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
|
251 |
|
252 |
async def initialize(self):
|
253 |
+
if self.db is None:
|
254 |
self.db = await ClientManager.get_client()
|
255 |
|
256 |
async def finalize(self):
|
257 |
+
if self.db is not None:
|
258 |
await ClientManager.release_client(self.db)
|
259 |
self.db = None
|
260 |
|
|
|
394 |
@final
|
395 |
@dataclass
|
396 |
class OracleVectorDBStorage(BaseVectorStorage):
|
397 |
+
db: OracleDB = field(init=False)
|
398 |
+
|
399 |
def __post_init__(self):
|
400 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
401 |
cosine_threshold = config.get("cosine_better_than_threshold")
|
|
|
406 |
self.cosine_better_than_threshold = cosine_threshold
|
407 |
|
408 |
async def initialize(self):
|
409 |
+
if self.db is None:
|
410 |
self.db = await ClientManager.get_client()
|
411 |
|
412 |
async def finalize(self):
|
413 |
+
if self.db is not None:
|
414 |
await ClientManager.release_client(self.db)
|
415 |
self.db = None
|
416 |
|
|
|
450 |
@final
|
451 |
@dataclass
|
452 |
class OracleGraphStorage(BaseGraphStorage):
|
453 |
+
db: OracleDB = field(init=False)
|
454 |
+
|
455 |
def __post_init__(self):
|
456 |
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
|
457 |
|
458 |
async def initialize(self):
|
459 |
+
if self.db is None:
|
460 |
self.db = await ClientManager.get_client()
|
461 |
|
462 |
async def finalize(self):
|
463 |
+
if self.db is not None:
|
464 |
await ClientManager.release_client(self.db)
|
465 |
self.db = None
|
466 |
|
lightrag/kg/postgres_impl.py
CHANGED
@@ -3,7 +3,7 @@ import inspect
|
|
3 |
import json
|
4 |
import os
|
5 |
import time
|
6 |
-
from dataclasses import dataclass
|
7 |
from typing import Any, Dict, List, Union, final
|
8 |
import numpy as np
|
9 |
import configparser
|
@@ -246,18 +246,17 @@ class ClientManager:
|
|
246 |
@final
|
247 |
@dataclass
|
248 |
class PGKVStorage(BaseKVStorage):
|
249 |
-
|
250 |
-
# db: PostgreSQLDB
|
251 |
|
252 |
def __post_init__(self):
|
253 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
254 |
|
255 |
async def initialize(self):
|
256 |
-
if
|
257 |
self.db = await ClientManager.get_client()
|
258 |
|
259 |
async def finalize(self):
|
260 |
-
if
|
261 |
await ClientManager.release_client(self.db)
|
262 |
self.db = None
|
263 |
|
@@ -379,6 +378,8 @@ class PGKVStorage(BaseKVStorage):
|
|
379 |
@final
|
380 |
@dataclass
|
381 |
class PGVectorStorage(BaseVectorStorage):
|
|
|
|
|
382 |
def __post_init__(self):
|
383 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
384 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
@@ -390,11 +391,11 @@ class PGVectorStorage(BaseVectorStorage):
|
|
390 |
self.cosine_better_than_threshold = cosine_threshold
|
391 |
|
392 |
async def initialize(self):
|
393 |
-
if
|
394 |
self.db = await ClientManager.get_client()
|
395 |
|
396 |
async def finalize(self):
|
397 |
-
if
|
398 |
await ClientManager.release_client(self.db)
|
399 |
self.db = None
|
400 |
|
@@ -514,12 +515,14 @@ class PGVectorStorage(BaseVectorStorage):
|
|
514 |
@final
|
515 |
@dataclass
|
516 |
class PGDocStatusStorage(DocStatusStorage):
|
|
|
|
|
517 |
async def initialize(self):
|
518 |
-
if
|
519 |
self.db = await ClientManager.get_client()
|
520 |
|
521 |
async def finalize(self):
|
522 |
-
if
|
523 |
await ClientManager.release_client(self.db)
|
524 |
self.db = None
|
525 |
|
@@ -662,6 +665,8 @@ class PGGraphQueryException(Exception):
|
|
662 |
@final
|
663 |
@dataclass
|
664 |
class PGGraphStorage(BaseGraphStorage):
|
|
|
|
|
665 |
@staticmethod
|
666 |
def load_nx_graph(file_name):
|
667 |
print("no preloading of graph with AGE in production")
|
@@ -673,11 +678,11 @@ class PGGraphStorage(BaseGraphStorage):
|
|
673 |
}
|
674 |
|
675 |
async def initialize(self):
|
676 |
-
if
|
677 |
self.db = await ClientManager.get_client()
|
678 |
|
679 |
async def finalize(self):
|
680 |
-
if
|
681 |
await ClientManager.release_client(self.db)
|
682 |
self.db = None
|
683 |
|
|
|
3 |
import json
|
4 |
import os
|
5 |
import time
|
6 |
+
from dataclasses import dataclass, field
|
7 |
from typing import Any, Dict, List, Union, final
|
8 |
import numpy as np
|
9 |
import configparser
|
|
|
246 |
@final
|
247 |
@dataclass
|
248 |
class PGKVStorage(BaseKVStorage):
|
249 |
+
db: PostgreSQLDB = field(init=False)
|
|
|
250 |
|
251 |
def __post_init__(self):
|
252 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
253 |
|
254 |
async def initialize(self):
|
255 |
+
if self.db is None:
|
256 |
self.db = await ClientManager.get_client()
|
257 |
|
258 |
async def finalize(self):
|
259 |
+
if self.db is not None:
|
260 |
await ClientManager.release_client(self.db)
|
261 |
self.db = None
|
262 |
|
|
|
378 |
@final
|
379 |
@dataclass
|
380 |
class PGVectorStorage(BaseVectorStorage):
|
381 |
+
db: PostgreSQLDB = field(init=False)
|
382 |
+
|
383 |
def __post_init__(self):
|
384 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
385 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
|
|
391 |
self.cosine_better_than_threshold = cosine_threshold
|
392 |
|
393 |
async def initialize(self):
|
394 |
+
if self.db is None:
|
395 |
self.db = await ClientManager.get_client()
|
396 |
|
397 |
async def finalize(self):
|
398 |
+
if self.db is not None:
|
399 |
await ClientManager.release_client(self.db)
|
400 |
self.db = None
|
401 |
|
|
|
515 |
@final
|
516 |
@dataclass
|
517 |
class PGDocStatusStorage(DocStatusStorage):
|
518 |
+
db: PostgreSQLDB = field(init=False)
|
519 |
+
|
520 |
async def initialize(self):
|
521 |
+
if self.db is None:
|
522 |
self.db = await ClientManager.get_client()
|
523 |
|
524 |
async def finalize(self):
|
525 |
+
if self.db is not None:
|
526 |
await ClientManager.release_client(self.db)
|
527 |
self.db = None
|
528 |
|
|
|
665 |
@final
|
666 |
@dataclass
|
667 |
class PGGraphStorage(BaseGraphStorage):
|
668 |
+
db: PostgreSQLDB = field(init=False)
|
669 |
+
|
670 |
@staticmethod
|
671 |
def load_nx_graph(file_name):
|
672 |
print("no preloading of graph with AGE in production")
|
|
|
678 |
}
|
679 |
|
680 |
async def initialize(self):
|
681 |
+
if self.db is None:
|
682 |
self.db = await ClientManager.get_client()
|
683 |
|
684 |
async def finalize(self):
|
685 |
+
if self.db is not None:
|
686 |
await ClientManager.release_client(self.db)
|
687 |
self.db = None
|
688 |
|
lightrag/kg/tidb_impl.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import asyncio
|
2 |
import os
|
3 |
-
from dataclasses import dataclass
|
4 |
from typing import Any, Union, final
|
5 |
|
6 |
import numpy as np
|
@@ -166,19 +166,18 @@ class ClientManager:
|
|
166 |
@final
|
167 |
@dataclass
|
168 |
class TiDBKVStorage(BaseKVStorage):
|
169 |
-
|
170 |
-
# db: TiDB
|
171 |
|
172 |
def __post_init__(self):
|
173 |
self._data = {}
|
174 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
175 |
|
176 |
async def initialize(self):
|
177 |
-
if
|
178 |
self.db = await ClientManager.get_client()
|
179 |
|
180 |
async def finalize(self):
|
181 |
-
if
|
182 |
await ClientManager.release_client(self.db)
|
183 |
self.db = None
|
184 |
|
@@ -280,6 +279,8 @@ class TiDBKVStorage(BaseKVStorage):
|
|
280 |
@final
|
281 |
@dataclass
|
282 |
class TiDBVectorDBStorage(BaseVectorStorage):
|
|
|
|
|
283 |
def __post_init__(self):
|
284 |
self._client_file_name = os.path.join(
|
285 |
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
@@ -294,11 +295,11 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
|
294 |
self.cosine_better_than_threshold = cosine_threshold
|
295 |
|
296 |
async def initialize(self):
|
297 |
-
if
|
298 |
self.db = await ClientManager.get_client()
|
299 |
|
300 |
async def finalize(self):
|
301 |
-
if
|
302 |
await ClientManager.release_client(self.db)
|
303 |
self.db = None
|
304 |
|
@@ -421,18 +422,17 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
|
421 |
@final
|
422 |
@dataclass
|
423 |
class TiDBGraphStorage(BaseGraphStorage):
|
424 |
-
|
425 |
-
# db: TiDB
|
426 |
|
427 |
def __post_init__(self):
|
428 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
429 |
|
430 |
async def initialize(self):
|
431 |
-
if
|
432 |
self.db = await ClientManager.get_client()
|
433 |
|
434 |
async def finalize(self):
|
435 |
-
if
|
436 |
await ClientManager.release_client(self.db)
|
437 |
self.db = None
|
438 |
|
|
|
1 |
import asyncio
|
2 |
import os
|
3 |
+
from dataclasses import dataclass, field
|
4 |
from typing import Any, Union, final
|
5 |
|
6 |
import numpy as np
|
|
|
166 |
@final
|
167 |
@dataclass
|
168 |
class TiDBKVStorage(BaseKVStorage):
|
169 |
+
db: TiDB = field(init=False)
|
|
|
170 |
|
171 |
def __post_init__(self):
|
172 |
self._data = {}
|
173 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
174 |
|
175 |
async def initialize(self):
|
176 |
+
if self.db is None:
|
177 |
self.db = await ClientManager.get_client()
|
178 |
|
179 |
async def finalize(self):
|
180 |
+
if self.db is not None:
|
181 |
await ClientManager.release_client(self.db)
|
182 |
self.db = None
|
183 |
|
|
|
279 |
@final
|
280 |
@dataclass
|
281 |
class TiDBVectorDBStorage(BaseVectorStorage):
|
282 |
+
db: TiDB = field(init=False)
|
283 |
+
|
284 |
def __post_init__(self):
|
285 |
self._client_file_name = os.path.join(
|
286 |
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
|
|
295 |
self.cosine_better_than_threshold = cosine_threshold
|
296 |
|
297 |
async def initialize(self):
|
298 |
+
if self.db is None:
|
299 |
self.db = await ClientManager.get_client()
|
300 |
|
301 |
async def finalize(self):
|
302 |
+
if self.db is not None:
|
303 |
await ClientManager.release_client(self.db)
|
304 |
self.db = None
|
305 |
|
|
|
422 |
@final
|
423 |
@dataclass
|
424 |
class TiDBGraphStorage(BaseGraphStorage):
|
425 |
+
db: TiDB = field(init=False)
|
|
|
426 |
|
427 |
def __post_init__(self):
|
428 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
429 |
|
430 |
async def initialize(self):
|
431 |
+
if self.db is None:
|
432 |
self.db = await ClientManager.get_client()
|
433 |
|
434 |
async def finalize(self):
|
435 |
+
if self.db is not None:
|
436 |
await ClientManager.release_client(self.db)
|
437 |
self.db = None
|
438 |
|