yangdx commited on
Commit
f8a75bc
·
1 Parent(s): 30d317a

Inject Postgres to LightRag storage class when needed

Browse files
Files changed (1) hide show
  1. lightrag/lightrag.py +68 -12
lightrag/lightrag.py CHANGED
@@ -17,6 +17,7 @@ from .base import (
17
  StorageNameSpace,
18
  )
19
  from .kg.oracle_impl import OracleDB
 
20
  from .namespace import NameSpace, make_namespace
21
  from .operate import (
22
  chunking_by_token_size,
@@ -394,6 +395,18 @@ class LightRAG:
394
  self.graph_storage_cls, global_config=global_config
395
  )
396
 
 
 
 
 
 
 
 
 
 
 
 
 
397
  # 检查是否使用了 Oracle 存储实现
398
  if (
399
  self.kv_storage == "OracleKVStorage"
@@ -403,14 +416,16 @@ class LightRAG:
403
  # 从环境变量或配置文件获取参数
404
  dbconfig = {
405
  "user": os.environ.get(
406
- "ORACLE_USER", config.get("oracle", "user", fallback=None)
 
407
  ),
408
  "password": os.environ.get(
409
  "ORACLE_PASSWORD",
410
  config.get("oracle", "password", fallback=None),
411
  ),
412
  "dsn": os.environ.get(
413
- "ORACLE_DSN", config.get("oracle", "dsn", fallback=None)
 
414
  ),
415
  "config_dir": os.environ.get(
416
  "ORACLE_CONFIG_DIR",
@@ -441,17 +456,58 @@ class LightRAG:
441
  if self.graph_storage == "OracleGraphStorage":
442
  self.graph_storage_cls.db = oracle_db
443
 
444
- self.json_doc_status_storage = self.key_string_value_json_storage_cls(
445
- namespace=self.namespace_prefix + "json_doc_status_storage",
446
- embedding_func=None,
447
- )
 
 
 
 
 
 
 
448
 
449
- self.llm_response_cache = self.key_string_value_json_storage_cls(
450
- namespace=make_namespace(
451
- self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
452
- ),
453
- embedding_func=self.embedding_func,
454
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
455
 
456
  ####
457
  # add embedding func by walter
 
17
  StorageNameSpace,
18
  )
19
  from .kg.oracle_impl import OracleDB
20
+ from .kg.postgres_impl import PostgreSQLDB
21
  from .namespace import NameSpace, make_namespace
22
  from .operate import (
23
  chunking_by_token_size,
 
395
  self.graph_storage_cls, global_config=global_config
396
  )
397
 
398
+ self.json_doc_status_storage = self.key_string_value_json_storage_cls(
399
+ namespace=self.namespace_prefix + "json_doc_status_storage",
400
+ embedding_func=None,
401
+ )
402
+
403
+ self.llm_response_cache = self.key_string_value_json_storage_cls(
404
+ namespace=make_namespace(
405
+ self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
406
+ ),
407
+ embedding_func=self.embedding_func,
408
+ )
409
+
410
  # 检查是否使用了 Oracle 存储实现
411
  if (
412
  self.kv_storage == "OracleKVStorage"
 
416
  # 从环境变量或配置文件获取参数
417
  dbconfig = {
418
  "user": os.environ.get(
419
+ "ORACLE_USER",
420
+ config.get("oracle", "user", fallback=None),
421
  ),
422
  "password": os.environ.get(
423
  "ORACLE_PASSWORD",
424
  config.get("oracle", "password", fallback=None),
425
  ),
426
  "dsn": os.environ.get(
427
+ "ORACLE_DSN",
428
+ config.get("oracle", "dsn", fallback=None),
429
  ),
430
  "config_dir": os.environ.get(
431
  "ORACLE_CONFIG_DIR",
 
456
  if self.graph_storage == "OracleGraphStorage":
457
  self.graph_storage_cls.db = oracle_db
458
 
459
+ # 检查是否使用了 PostgreSQL 存储实现
460
+ if (
461
+ self.kv_storage == "PGKVStorage"
462
+ or self.vector_storage == "PGVectorStorage"
463
+ or self.graph_storage == "PGGraphStorage"
464
+ or self.json_doc_status_storage == "PGDocStatusStorage"
465
+ ):
466
+ # 读取配置文件
467
+ config_parser = configparser.ConfigParser()
468
+ if os.path.exists("config.ini"):
469
+ config_parser.read("config.ini")
470
 
471
+ # 从环境变量或配置文件获取参数
472
+ dbconfig = {
473
+ "host": os.environ.get(
474
+ "POSTGRES_HOST",
475
+ config.get("postgres", "host", fallback="localhost"),
476
+ ),
477
+ "port": os.environ.get(
478
+ "POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
479
+ ),
480
+ "user": os.environ.get(
481
+ "POSTGRES_USER", config.get("postgres", "user", fallback=None)
482
+ ),
483
+ "password": os.environ.get(
484
+ "POSTGRES_PASSWORD",
485
+ config.get("postgres", "password", fallback=None),
486
+ ),
487
+ "database": os.environ.get(
488
+ "POSTGRES_DATABASE",
489
+ config.get("postgres", "database", fallback=None),
490
+ ),
491
+ "workspace": os.environ.get(
492
+ "POSTGRES_WORKSPACE",
493
+ config.get("postgres", "workspace", fallback="default"),
494
+ ),
495
+ }
496
+
497
+ # 初始化 PostgreSQLDB 对象
498
+ postgres_db = PostgreSQLDB(dbconfig)
499
+ loop = always_get_an_event_loop()
500
+ loop.run_until_complete(postgres_db.initdb())
501
+
502
+ # 只对 PostgreSQL 实现的存储类注入 db 对象
503
+ if self.kv_storage == "PGKVStorage":
504
+ self.key_string_value_json_storage_cls.db = postgres_db
505
+ if self.vector_storage == "PGVectorStorage":
506
+ self.vector_db_storage_cls.db = postgres_db
507
+ if self.graph_storage == "PGGraphStorage":
508
+ self.graph_storage_cls.db = postgres_db
509
+ if self.json_doc_status_storage == "OracleGraphStorage":
510
+ self.json_doc_status_storage = postgres_db
511
 
512
  ####
513
  # add embedding func by walter