yangdx commited on
Commit
30d317a
·
1 Parent(s): f224c7f

Inject oracle db to LightRag storage class when needed

Browse files
Files changed (2) hide show
  1. lightrag/kg/oracle_impl.py +2 -1
  2. lightrag/lightrag.py +52 -0
lightrag/kg/oracle_impl.py CHANGED
@@ -361,7 +361,8 @@ class OracleVectorDBStorage(BaseVectorStorage):
361
 
362
  @dataclass
363
  class OracleGraphStorage(BaseGraphStorage):
364
- """基于Oracle的图存储模块"""
 
365
 
366
  def __post_init__(self):
367
  """从graphml文件加载图"""
 
361
 
362
  @dataclass
363
  class OracleGraphStorage(BaseGraphStorage):
364
+ # should pass db object to self.db
365
+ db: OracleDB = None
366
 
367
  def __post_init__(self):
368
  """从graphml文件加载图"""
lightrag/lightrag.py CHANGED
@@ -1,5 +1,6 @@
1
  import asyncio
2
  import os
 
3
  from dataclasses import asdict, dataclass, field
4
  from datetime import datetime
5
  from functools import partial
@@ -15,6 +16,7 @@ from .base import (
15
  QueryParam,
16
  StorageNameSpace,
17
  )
 
18
  from .namespace import NameSpace, make_namespace
19
  from .operate import (
20
  chunking_by_token_size,
@@ -35,6 +37,9 @@ from .utils import (
35
  set_logger,
36
  )
37
 
 
 
 
38
  # Storage type and implementation compatibility validation table
39
  STORAGE_IMPLEMENTATIONS = {
40
  "KV_STORAGE": {
@@ -389,6 +394,53 @@ class LightRAG:
389
  self.graph_storage_cls, global_config=global_config
390
  )
391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392
  self.json_doc_status_storage = self.key_string_value_json_storage_cls(
393
  namespace=self.namespace_prefix + "json_doc_status_storage",
394
  embedding_func=None,
 
1
  import asyncio
2
  import os
3
+ import configparser
4
  from dataclasses import asdict, dataclass, field
5
  from datetime import datetime
6
  from functools import partial
 
16
  QueryParam,
17
  StorageNameSpace,
18
  )
19
+ from .kg.oracle_impl import OracleDB
20
  from .namespace import NameSpace, make_namespace
21
  from .operate import (
22
  chunking_by_token_size,
 
37
  set_logger,
38
  )
39
 
40
+ config = configparser.ConfigParser()
41
+ config.read("config.ini", "utf-8")
42
+
43
  # Storage type and implementation compatibility validation table
44
  STORAGE_IMPLEMENTATIONS = {
45
  "KV_STORAGE": {
 
394
  self.graph_storage_cls, global_config=global_config
395
  )
396
 
397
+ # 检查是否使用了 Oracle 存储实现
398
+ if (
399
+ self.kv_storage == "OracleKVStorage"
400
+ or self.vector_storage == "OracleVectorDBStorage"
401
+ or self.graph_storage == "OracleGraphStorage"
402
+ ):
403
+ # 从环境变量或配置文件获取参数
404
+ dbconfig = {
405
+ "user": os.environ.get(
406
+ "ORACLE_USER", config.get("oracle", "user", fallback=None)
407
+ ),
408
+ "password": os.environ.get(
409
+ "ORACLE_PASSWORD",
410
+ config.get("oracle", "password", fallback=None),
411
+ ),
412
+ "dsn": os.environ.get(
413
+ "ORACLE_DSN", config.get("oracle", "dsn", fallback=None)
414
+ ),
415
+ "config_dir": os.environ.get(
416
+ "ORACLE_CONFIG_DIR",
417
+ config.get("oracle", "config_dir", fallback=None),
418
+ ),
419
+ "wallet_location": os.environ.get(
420
+ "ORACLE_WALLET_LOCATION",
421
+ config.get("oracle", "wallet_location", fallback=None),
422
+ ),
423
+ "wallet_password": os.environ.get(
424
+ "ORACLE_WALLET_PASSWORD",
425
+ config.get("oracle", "wallet_password", fallback=None),
426
+ ),
427
+ "workspace": os.environ.get(
428
+ "ORACLE_WORKSPACE",
429
+ config.get("oracle", "workspace", fallback="default"),
430
+ ),
431
+ }
432
+
433
+ # 初始化 OracleDB 对象
434
+ oracle_db = OracleDB(dbconfig)
435
+
436
+ # 只对 Oracle 实现的存储类注入 db 对象
437
+ if self.kv_storage == "OracleKVStorage":
438
+ self.key_string_value_json_storage_cls.db = oracle_db
439
+ if self.vector_storage == "OracleVectorDBStorage":
440
+ self.vector_db_storage_cls.db = oracle_db
441
+ if self.graph_storage == "OracleGraphStorage":
442
+ self.graph_storage_cls.db = oracle_db
443
+
444
  self.json_doc_status_storage = self.key_string_value_json_storage_cls(
445
  namespace=self.namespace_prefix + "json_doc_status_storage",
446
  embedding_func=None,