yangdx
committed on
Commit
·
1fce92c
1
Parent(s):
3c340f7
Inject TiDB into LightRAG storage when needed
Browse files
- lightrag/kg/tidb_impl.py +7 -0
- lightrag/lightrag.py +51 -2
lightrag/kg/tidb_impl.py
CHANGED
@@ -102,6 +102,8 @@ class TiDB:
|
|
102 |
@dataclass
|
103 |
class TiDBKVStorage(BaseKVStorage):
|
104 |
# should pass db object to self.db
|
|
|
|
|
105 |
def __post_init__(self):
|
106 |
self._data = {}
|
107 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
@@ -208,6 +210,8 @@ class TiDBKVStorage(BaseKVStorage):
|
|
208 |
|
209 |
@dataclass
|
210 |
class TiDBVectorDBStorage(BaseVectorStorage):
|
|
|
|
|
211 |
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
212 |
|
213 |
def __post_init__(self):
|
@@ -329,6 +333,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
|
329 |
|
330 |
@dataclass
|
331 |
class TiDBGraphStorage(BaseGraphStorage):
|
|
|
|
|
|
|
332 |
def __post_init__(self):
|
333 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
334 |
|
|
|
102 |
@dataclass
|
103 |
class TiDBKVStorage(BaseKVStorage):
|
104 |
# should pass db object to self.db
|
105 |
+
db: TiDB = None
|
106 |
+
|
107 |
def __post_init__(self):
|
108 |
self._data = {}
|
109 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
|
|
210 |
|
211 |
@dataclass
|
212 |
class TiDBVectorDBStorage(BaseVectorStorage):
|
213 |
+
# should pass db object to self.db
|
214 |
+
db: TiDB = None
|
215 |
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
216 |
|
217 |
def __post_init__(self):
|
|
|
333 |
|
334 |
@dataclass
|
335 |
class TiDBGraphStorage(BaseGraphStorage):
|
336 |
+
# should pass db object to self.db
|
337 |
+
db: TiDB = None
|
338 |
+
|
339 |
def __post_init__(self):
|
340 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
341 |
|
lightrag/lightrag.py
CHANGED
@@ -16,8 +16,6 @@ from .base import (
|
|
16 |
QueryParam,
|
17 |
StorageNameSpace,
|
18 |
)
|
19 |
-
from .kg.oracle_impl import OracleDB
|
20 |
-
from .kg.postgres_impl import PostgreSQLDB
|
21 |
from .namespace import NameSpace, make_namespace
|
22 |
from .operate import (
|
23 |
chunking_by_token_size,
|
@@ -446,6 +444,7 @@ class LightRAG:
|
|
446 |
}
|
447 |
|
448 |
# 初始化 OracleDB 对象
|
|
|
449 |
oracle_db = OracleDB(dbconfig)
|
450 |
# Check if DB tables exist, if not, tables will be created
|
451 |
loop = always_get_an_event_loop()
|
@@ -459,6 +458,55 @@ class LightRAG:
|
|
459 |
if self.graph_storage == "OracleGraphStorage":
|
460 |
self.graph_storage_cls.db = oracle_db
|
461 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
462 |
# 检查是否使用了 PostgreSQL 存储实现
|
463 |
if (
|
464 |
self.kv_storage == "PGKVStorage"
|
@@ -498,6 +546,7 @@ class LightRAG:
|
|
498 |
}
|
499 |
|
500 |
# 初始化 PostgreSQLDB 对象
|
|
|
501 |
postgres_db = PostgreSQLDB(dbconfig)
|
502 |
# Initialize and check tables
|
503 |
loop = always_get_an_event_loop()
|
|
|
16 |
QueryParam,
|
17 |
StorageNameSpace,
|
18 |
)
|
|
|
|
|
19 |
from .namespace import NameSpace, make_namespace
|
20 |
from .operate import (
|
21 |
chunking_by_token_size,
|
|
|
444 |
}
|
445 |
|
446 |
# 初始化 OracleDB 对象
|
447 |
+
from .kg.oracle_impl import OracleDB
|
448 |
oracle_db = OracleDB(dbconfig)
|
449 |
# Check if DB tables exist, if not, tables will be created
|
450 |
loop = always_get_an_event_loop()
|
|
|
458 |
if self.graph_storage == "OracleGraphStorage":
|
459 |
self.graph_storage_cls.db = oracle_db
|
460 |
|
461 |
+
# 检查是否使用了 TiDB 存储实现
|
462 |
+
if (
|
463 |
+
self.kv_storage == "TiDBKVStorage"
|
464 |
+
or self.vector_storage == "TiDBVectorDBStorage"
|
465 |
+
or self.graph_storage == "TiDBGraphStorage"
|
466 |
+
):
|
467 |
+
# 从环境变量或配置文件获取参数
|
468 |
+
dbconfig = {
|
469 |
+
"host": os.environ.get(
|
470 |
+
"TIDB_HOST",
|
471 |
+
config.get("tidb", "host", fallback="localhost"),
|
472 |
+
),
|
473 |
+
"port": os.environ.get(
|
474 |
+
"TIDB_PORT",
|
475 |
+
config.get("tidb", "port", fallback=4000)
|
476 |
+
),
|
477 |
+
"user": os.environ.get(
|
478 |
+
"TIDB_USER",
|
479 |
+
config.get("tidb", "user", fallback=None),
|
480 |
+
),
|
481 |
+
"password": os.environ.get(
|
482 |
+
"TIDB_PASSWORD",
|
483 |
+
config.get("tidb", "password", fallback=None),
|
484 |
+
),
|
485 |
+
"database": os.environ.get(
|
486 |
+
"TIDB_DATABASE",
|
487 |
+
config.get("tidb", "database", fallback=None),
|
488 |
+
),
|
489 |
+
"workspace": os.environ.get(
|
490 |
+
"TIDB_WORKSPACE",
|
491 |
+
config.get("tidb", "workspace", fallback="default"),
|
492 |
+
),
|
493 |
+
}
|
494 |
+
|
495 |
+
# 初始化 TiDB 对象
|
496 |
+
from .kg.tidb_impl import TiDB
|
497 |
+
tidb_db = TiDB(dbconfig)
|
498 |
+
# Check if DB tables exist, if not, tables will be created
|
499 |
+
loop = always_get_an_event_loop()
|
500 |
+
loop.run_until_complete(tidb_db.check_tables())
|
501 |
+
|
502 |
+
# 只对 TiDB 实现的存储类注入 db 对象
|
503 |
+
if self.kv_storage == "TiDBKVStorage":
|
504 |
+
self.key_string_value_json_storage_cls.db = tidb_db
|
505 |
+
if self.vector_storage == "TiDBVectorDBStorage":
|
506 |
+
self.vector_db_storage_cls.db = tidb_db
|
507 |
+
if self.graph_storage == "TiDBGraphStorage":
|
508 |
+
self.graph_storage_cls.db = tidb_db
|
509 |
+
|
510 |
# 检查是否使用了 PostgreSQL 存储实现
|
511 |
if (
|
512 |
self.kv_storage == "PGKVStorage"
|
|
|
546 |
}
|
547 |
|
548 |
# 初始化 PostgreSQLDB 对象
|
549 |
+
from .kg.postgres_impl import PostgreSQLDB
|
550 |
postgres_db = PostgreSQLDB(dbconfig)
|
551 |
# Initialize and check tables
|
552 |
loop = always_get_an_event_loop()
|