yangdx committed on
Commit
1fce92c
·
1 Parent(s): 3c340f7

Inject TiDB into LightRAG storage when needed

Browse files
Files changed (2) hide show
  1. lightrag/kg/tidb_impl.py +7 -0
  2. lightrag/lightrag.py +51 -2
lightrag/kg/tidb_impl.py CHANGED
@@ -102,6 +102,8 @@ class TiDB:
102
  @dataclass
103
  class TiDBKVStorage(BaseKVStorage):
104
  # should pass db object to self.db
 
 
105
  def __post_init__(self):
106
  self._data = {}
107
  self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -208,6 +210,8 @@ class TiDBKVStorage(BaseKVStorage):
208
 
209
  @dataclass
210
  class TiDBVectorDBStorage(BaseVectorStorage):
 
 
211
  cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
212
 
213
  def __post_init__(self):
@@ -329,6 +333,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
329
 
330
  @dataclass
331
  class TiDBGraphStorage(BaseGraphStorage):
 
 
 
332
  def __post_init__(self):
333
  self._max_batch_size = self.global_config["embedding_batch_num"]
334
 
 
102
  @dataclass
103
  class TiDBKVStorage(BaseKVStorage):
104
  # should pass db object to self.db
105
+ db: TiDB = None
106
+
107
  def __post_init__(self):
108
  self._data = {}
109
  self._max_batch_size = self.global_config["embedding_batch_num"]
 
210
 
211
  @dataclass
212
  class TiDBVectorDBStorage(BaseVectorStorage):
213
+ # should pass db object to self.db
214
+ db: TiDB = None
215
  cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
216
 
217
  def __post_init__(self):
 
333
 
334
  @dataclass
335
  class TiDBGraphStorage(BaseGraphStorage):
336
+ # should pass db object to self.db
337
+ db: TiDB = None
338
+
339
  def __post_init__(self):
340
  self._max_batch_size = self.global_config["embedding_batch_num"]
341
 
lightrag/lightrag.py CHANGED
@@ -16,8 +16,6 @@ from .base import (
16
  QueryParam,
17
  StorageNameSpace,
18
  )
19
- from .kg.oracle_impl import OracleDB
20
- from .kg.postgres_impl import PostgreSQLDB
21
  from .namespace import NameSpace, make_namespace
22
  from .operate import (
23
  chunking_by_token_size,
@@ -446,6 +444,7 @@ class LightRAG:
446
  }
447
 
448
  # Initialize the OracleDB object
 
449
  oracle_db = OracleDB(dbconfig)
450
  # Check if DB tables exist, if not, tables will be created
451
  loop = always_get_an_event_loop()
@@ -459,6 +458,55 @@ class LightRAG:
459
  if self.graph_storage == "OracleGraphStorage":
460
  self.graph_storage_cls.db = oracle_db
461
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
462
  # Check whether a PostgreSQL storage implementation is in use
463
  if (
464
  self.kv_storage == "PGKVStorage"
@@ -498,6 +546,7 @@ class LightRAG:
498
  }
499
 
500
  # Initialize the PostgreSQLDB object
 
501
  postgres_db = PostgreSQLDB(dbconfig)
502
  # Initialize and check tables
503
  loop = always_get_an_event_loop()
 
16
  QueryParam,
17
  StorageNameSpace,
18
  )
 
 
19
  from .namespace import NameSpace, make_namespace
20
  from .operate import (
21
  chunking_by_token_size,
 
444
  }
445
 
446
  # Initialize the OracleDB object
447
+ from .kg.oracle_impl import OracleDB
448
  oracle_db = OracleDB(dbconfig)
449
  # Check if DB tables exist, if not, tables will be created
450
  loop = always_get_an_event_loop()
 
458
  if self.graph_storage == "OracleGraphStorage":
459
  self.graph_storage_cls.db = oracle_db
460
 
461
+ # Check whether a TiDB storage implementation is in use
462
+ if (
463
+ self.kv_storage == "TiDBKVStorage"
464
+ or self.vector_storage == "TiDBVectorDBStorage"
465
+ or self.graph_storage == "TiDBGraphStorage"
466
+ ):
467
+ # Read parameters from environment variables or the config file
468
+ dbconfig = {
469
+ "host": os.environ.get(
470
+ "TIDB_HOST",
471
+ config.get("tidb", "host", fallback="localhost"),
472
+ ),
473
+ "port": os.environ.get(
474
+ "TIDB_PORT",
475
+ config.get("tidb", "port", fallback=4000)
476
+ ),
477
+ "user": os.environ.get(
478
+ "TIDB_USER",
479
+ config.get("tidb", "user", fallback=None),
480
+ ),
481
+ "password": os.environ.get(
482
+ "TIDB_PASSWORD",
483
+ config.get("tidb", "password", fallback=None),
484
+ ),
485
+ "database": os.environ.get(
486
+ "TIDB_DATABASE",
487
+ config.get("tidb", "database", fallback=None),
488
+ ),
489
+ "workspace": os.environ.get(
490
+ "TIDB_WORKSPACE",
491
+ config.get("tidb", "workspace", fallback="default"),
492
+ ),
493
+ }
494
+
495
+ # Initialize the TiDB object
496
+ from .kg.tidb_impl import TiDB
497
+ tidb_db = TiDB(dbconfig)
498
+ # Check if DB tables exist, if not, tables will be created
499
+ loop = always_get_an_event_loop()
500
+ loop.run_until_complete(tidb_db.check_tables())
501
+
502
+ # Inject the db object only into the TiDB-backed storage classes
503
+ if self.kv_storage == "TiDBKVStorage":
504
+ self.key_string_value_json_storage_cls.db = tidb_db
505
+ if self.vector_storage == "TiDBVectorDBStorage":
506
+ self.vector_db_storage_cls.db = tidb_db
507
+ if self.graph_storage == "TiDBGraphStorage":
508
+ self.graph_storage_cls.db = tidb_db
509
+
510
  # Check whether a PostgreSQL storage implementation is in use
511
  if (
512
  self.kv_storage == "PGKVStorage"
 
546
  }
547
 
548
  # Initialize the PostgreSQLDB object
549
+ from .kg.postgres_impl import PostgreSQLDB
550
  postgres_db = PostgreSQLDB(dbconfig)
551
  # Initialize and check tables
552
  loop = always_get_an_event_loop()