yangdx
commited on
Commit
·
f8a75bc
1
Parent(s):
30d317a
Inject Postgres to LightRag storage class when needed
Browse files- lightrag/lightrag.py +68 -12
lightrag/lightrag.py
CHANGED
@@ -17,6 +17,7 @@ from .base import (
|
|
17 |
StorageNameSpace,
|
18 |
)
|
19 |
from .kg.oracle_impl import OracleDB
|
|
|
20 |
from .namespace import NameSpace, make_namespace
|
21 |
from .operate import (
|
22 |
chunking_by_token_size,
|
@@ -394,6 +395,18 @@ class LightRAG:
|
|
394 |
self.graph_storage_cls, global_config=global_config
|
395 |
)
|
396 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
397 |
# 检查是否使用了 Oracle 存储实现
|
398 |
if (
|
399 |
self.kv_storage == "OracleKVStorage"
|
@@ -403,14 +416,16 @@ class LightRAG:
|
|
403 |
# 从环境变量或配置文件获取参数
|
404 |
dbconfig = {
|
405 |
"user": os.environ.get(
|
406 |
-
"ORACLE_USER",
|
|
|
407 |
),
|
408 |
"password": os.environ.get(
|
409 |
"ORACLE_PASSWORD",
|
410 |
config.get("oracle", "password", fallback=None),
|
411 |
),
|
412 |
"dsn": os.environ.get(
|
413 |
-
"ORACLE_DSN",
|
|
|
414 |
),
|
415 |
"config_dir": os.environ.get(
|
416 |
"ORACLE_CONFIG_DIR",
|
@@ -441,17 +456,58 @@ class LightRAG:
|
|
441 |
if self.graph_storage == "OracleGraphStorage":
|
442 |
self.graph_storage_cls.db = oracle_db
|
443 |
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
448 |
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
455 |
|
456 |
####
|
457 |
# add embedding func by walter
|
|
|
17 |
StorageNameSpace,
|
18 |
)
|
19 |
from .kg.oracle_impl import OracleDB
|
20 |
+
from .kg.postgres_impl import PostgreSQLDB
|
21 |
from .namespace import NameSpace, make_namespace
|
22 |
from .operate import (
|
23 |
chunking_by_token_size,
|
|
|
395 |
self.graph_storage_cls, global_config=global_config
|
396 |
)
|
397 |
|
398 |
+
self.json_doc_status_storage = self.key_string_value_json_storage_cls(
|
399 |
+
namespace=self.namespace_prefix + "json_doc_status_storage",
|
400 |
+
embedding_func=None,
|
401 |
+
)
|
402 |
+
|
403 |
+
self.llm_response_cache = self.key_string_value_json_storage_cls(
|
404 |
+
namespace=make_namespace(
|
405 |
+
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
|
406 |
+
),
|
407 |
+
embedding_func=self.embedding_func,
|
408 |
+
)
|
409 |
+
|
410 |
# 检查是否使用了 Oracle 存储实现
|
411 |
if (
|
412 |
self.kv_storage == "OracleKVStorage"
|
|
|
416 |
# 从环境变量或配置文件获取参数
|
417 |
dbconfig = {
|
418 |
"user": os.environ.get(
|
419 |
+
"ORACLE_USER",
|
420 |
+
config.get("oracle", "user", fallback=None),
|
421 |
),
|
422 |
"password": os.environ.get(
|
423 |
"ORACLE_PASSWORD",
|
424 |
config.get("oracle", "password", fallback=None),
|
425 |
),
|
426 |
"dsn": os.environ.get(
|
427 |
+
"ORACLE_DSN",
|
428 |
+
config.get("oracle", "dsn", fallback=None),
|
429 |
),
|
430 |
"config_dir": os.environ.get(
|
431 |
"ORACLE_CONFIG_DIR",
|
|
|
456 |
if self.graph_storage == "OracleGraphStorage":
|
457 |
self.graph_storage_cls.db = oracle_db
|
458 |
|
459 |
+
# 检查是否使用了 PostgreSQL 存储实现
|
460 |
+
if (
|
461 |
+
self.kv_storage == "PGKVStorage"
|
462 |
+
or self.vector_storage == "PGVectorStorage"
|
463 |
+
or self.graph_storage == "PGGraphStorage"
|
464 |
+
or self.json_doc_status_storage == "PGDocStatusStorage"
|
465 |
+
):
|
466 |
+
# 读取配置文件
|
467 |
+
config_parser = configparser.ConfigParser()
|
468 |
+
if os.path.exists("config.ini"):
|
469 |
+
config_parser.read("config.ini")
|
470 |
|
471 |
+
# 从环境变量或配置文件获取参数
|
472 |
+
dbconfig = {
|
473 |
+
"host": os.environ.get(
|
474 |
+
"POSTGRES_HOST",
|
475 |
+
config.get("postgres", "host", fallback="localhost"),
|
476 |
+
),
|
477 |
+
"port": os.environ.get(
|
478 |
+
"POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
|
479 |
+
),
|
480 |
+
"user": os.environ.get(
|
481 |
+
"POSTGRES_USER", config.get("postgres", "user", fallback=None)
|
482 |
+
),
|
483 |
+
"password": os.environ.get(
|
484 |
+
"POSTGRES_PASSWORD",
|
485 |
+
config.get("postgres", "password", fallback=None),
|
486 |
+
),
|
487 |
+
"database": os.environ.get(
|
488 |
+
"POSTGRES_DATABASE",
|
489 |
+
config.get("postgres", "database", fallback=None),
|
490 |
+
),
|
491 |
+
"workspace": os.environ.get(
|
492 |
+
"POSTGRES_WORKSPACE",
|
493 |
+
config.get("postgres", "workspace", fallback="default"),
|
494 |
+
),
|
495 |
+
}
|
496 |
+
|
497 |
+
# 初始化 PostgreSQLDB 对象
|
498 |
+
postgres_db = PostgreSQLDB(dbconfig)
|
499 |
+
loop = always_get_an_event_loop()
|
500 |
+
loop.run_until_complete(postgres_db.initdb())
|
501 |
+
|
502 |
+
# 只对 PostgreSQL 实现的存储类注入 db 对象
|
503 |
+
if self.kv_storage == "PGKVStorage":
|
504 |
+
self.key_string_value_json_storage_cls.db = postgres_db
|
505 |
+
if self.vector_storage == "PGVectorStorage":
|
506 |
+
self.vector_db_storage_cls.db = postgres_db
|
507 |
+
if self.graph_storage == "PGGraphStorage":
|
508 |
+
self.graph_storage_cls.db = postgres_db
|
509 |
+
if self.json_doc_status_storage == "OracleGraphStorage":
|
510 |
+
self.json_doc_status_storage = postgres_db
|
511 |
|
512 |
####
|
513 |
# add embedding func by walter
|