yangdx committed on
Commit
a7c0d7b
·
1 Parent(s): feb48a4

refactor: Implement dynamic database module imports

Browse files

- Consolidate database instance management
- Improve database management and error handling
- Enhance error handling and logging

Files changed (1) hide show
  1. lightrag/api/lightrag_server.py +76 -98
lightrag/api/lightrag_server.py CHANGED
@@ -41,25 +41,28 @@ from .ollama_api import (
41
  OllamaAPI,
42
  )
43
  from .ollama_api import ollama_server_infos
44
- from ..kg.postgres_impl import (
45
- PostgreSQLDB,
46
- PGKVStorage,
47
- PGVectorStorage,
48
- PGGraphStorage,
49
- PGDocStatusStorage,
50
- )
51
- from ..kg.oracle_impl import (
52
- OracleDB,
53
- OracleKVStorage,
54
- OracleVectorDBStorage,
55
- OracleGraphStorage,
56
- )
57
- from ..kg.tidb_impl import (
58
- TiDB,
59
- TiDBKVStorage,
60
- TiDBVectorDBStorage,
61
- TiDBGraphStorage,
62
- )
 
 
 
63
 
64
  # Load environment variables
65
  try:
@@ -333,28 +336,28 @@ def parse_args() -> argparse.Namespace:
333
  default=get_env_value(
334
  "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
335
  ),
336
- help=f"KV存储实现 (default: {DefaultRAGStorageConfig.KV_STORAGE})",
337
  )
338
  parser.add_argument(
339
  "--doc-status-storage",
340
  default=get_env_value(
341
  "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
342
  ),
343
- help=f"文档状态存储实现 (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})",
344
  )
345
  parser.add_argument(
346
  "--graph-storage",
347
  default=get_env_value(
348
  "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
349
  ),
350
- help=f"图存储实现 (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})",
351
  )
352
  parser.add_argument(
353
  "--vector-storage",
354
  default=get_env_value(
355
  "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
356
  ),
357
- help=f"向量存储实现 (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})",
358
  )
359
 
360
  # Bindings configuration
@@ -890,72 +893,47 @@ def create_app(args):
890
  async def lifespan(app: FastAPI):
891
  """Lifespan context manager for startup and shutdown events"""
892
  # Initialize database connections
893
- postgres_db = None
894
- oracle_db = None
895
- tidb_db = None
896
  # Store background tasks
897
  app.state.background_tasks = set()
898
 
899
  try:
900
- # Check if PostgreSQL is needed
901
- if any(
902
- isinstance(
903
- storage_instance,
904
- (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
905
- )
906
- for _, storage_instance in storage_instances
907
- ):
908
- postgres_db = PostgreSQLDB(_get_postgres_config())
909
- await postgres_db.initdb()
910
- await postgres_db.check_tables()
911
- for storage_name, storage_instance in storage_instances:
912
- if isinstance(
913
- storage_instance,
914
- (
915
- PGKVStorage,
916
- PGVectorStorage,
917
- PGGraphStorage,
918
- PGDocStatusStorage,
919
- ),
920
- ):
921
- storage_instance.db = postgres_db
922
- logger.info(f"Injected postgres_db to {storage_name}")
923
-
924
- # Check if Oracle is needed
925
- if any(
926
- isinstance(
927
- storage_instance,
928
- (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
929
- )
930
- for _, storage_instance in storage_instances
931
- ):
932
- oracle_db = OracleDB(_get_oracle_config())
933
- await oracle_db.check_tables()
934
- for storage_name, storage_instance in storage_instances:
935
- if isinstance(
936
- storage_instance,
937
- (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
938
- ):
939
- storage_instance.db = oracle_db
940
- logger.info(f"Injected oracle_db to {storage_name}")
941
-
942
- # Check if TiDB is needed
943
- if any(
944
- isinstance(
945
- storage_instance,
946
- (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage),
947
- )
948
- for _, storage_instance in storage_instances
949
- ):
950
- tidb_db = TiDB(_get_tidb_config())
951
- await tidb_db.check_tables()
952
- for storage_name, storage_instance in storage_instances:
953
- if isinstance(
954
- storage_instance,
955
- (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage),
956
- ):
957
- storage_instance.db = tidb_db
958
- logger.info(f"Injected tidb_db to {storage_name}")
959
 
960
  # Auto scan documents if enabled
961
  if args.auto_scan_at_startup:
@@ -980,18 +958,18 @@ def create_app(args):
980
  yield
981
 
982
  finally:
983
- # Cleanup database connections
984
- if postgres_db and hasattr(postgres_db, "pool"):
985
- await postgres_db.pool.close()
986
- logger.info("Closed PostgreSQL connection pool")
987
-
988
- if oracle_db and hasattr(oracle_db, "pool"):
989
- await oracle_db.pool.close()
990
- logger.info("Closed Oracle connection pool")
991
-
992
- if tidb_db and hasattr(tidb_db, "pool"):
993
- await tidb_db.pool.close()
994
- logger.info("Closed TiDB connection pool")
995
 
996
  # Initialize FastAPI
997
  app = FastAPI(
@@ -1311,7 +1289,7 @@ def create_app(args):
1311
  case ".pdf":
1312
  if not pm.is_installed("pypdf2"):
1313
  pm.install("pypdf2")
1314
- from PyPDF2 import PdfReader
1315
  from io import BytesIO
1316
 
1317
  pdf_file = BytesIO(file)
 
41
  OllamaAPI,
42
  )
43
  from .ollama_api import ollama_server_infos
44
+ def get_db_type_from_storage_class(class_name: str) -> str | None:
45
+ """Determine database type based on storage class name"""
46
+ if class_name.startswith("PG"):
47
+ return "postgres"
48
+ elif class_name.startswith("Oracle"):
49
+ return "oracle"
50
+ elif class_name.startswith("TiDB"):
51
+ return "tidb"
52
+ return None
53
+
54
+ def import_db_module(db_type: str):
55
+ """Dynamically import database module"""
56
+ if db_type == "postgres":
57
+ from ..kg.postgres_impl import PostgreSQLDB
58
+ return PostgreSQLDB
59
+ elif db_type == "oracle":
60
+ from ..kg.oracle_impl import OracleDB
61
+ return OracleDB
62
+ elif db_type == "tidb":
63
+ from ..kg.tidb_impl import TiDB
64
+ return TiDB
65
+ return None
66
 
67
  # Load environment variables
68
  try:
 
336
  default=get_env_value(
337
  "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
338
  ),
339
+ help=f"KV storage implementation (default: {DefaultRAGStorageConfig.KV_STORAGE})",
340
  )
341
  parser.add_argument(
342
  "--doc-status-storage",
343
  default=get_env_value(
344
  "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
345
  ),
346
+ help=f"Document status storage implementation (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})",
347
  )
348
  parser.add_argument(
349
  "--graph-storage",
350
  default=get_env_value(
351
  "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
352
  ),
353
+ help=f"Graph storage implementation (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})",
354
  )
355
  parser.add_argument(
356
  "--vector-storage",
357
  default=get_env_value(
358
  "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
359
  ),
360
+ help=f"Vector storage implementation (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})",
361
  )
362
 
363
  # Bindings configuration
 
893
  async def lifespan(app: FastAPI):
894
  """Lifespan context manager for startup and shutdown events"""
895
  # Initialize database connections
896
+ db_instances = {}
 
 
897
  # Store background tasks
898
  app.state.background_tasks = set()
899
 
900
  try:
901
+ # Check which database types are used
902
+ db_types = set()
903
+ for storage_name, storage_instance in storage_instances:
904
+ db_type = get_db_type_from_storage_class(storage_instance.__class__.__name__)
905
+ if db_type:
906
+ db_types.add(db_type)
907
+
908
+ # Import and initialize databases as needed
909
+ for db_type in db_types:
910
+ if db_type == "postgres":
911
+ DB = import_db_module("postgres")
912
+ db = DB(_get_postgres_config())
913
+ await db.initdb()
914
+ await db.check_tables()
915
+ db_instances["postgres"] = db
916
+ elif db_type == "oracle":
917
+ DB = import_db_module("oracle")
918
+ db = DB(_get_oracle_config())
919
+ await db.check_tables()
920
+ db_instances["oracle"] = db
921
+ elif db_type == "tidb":
922
+ DB = import_db_module("tidb")
923
+ db = DB(_get_tidb_config())
924
+ await db.check_tables()
925
+ db_instances["tidb"] = db
926
+
927
+ # Inject database instances into storage classes
928
+ for storage_name, storage_instance in storage_instances:
929
+ db_type = get_db_type_from_storage_class(storage_instance.__class__.__name__)
930
+ if db_type:
931
+ if db_type not in db_instances:
932
+ error_msg = f"Database type '{db_type}' is required by {storage_name} but not initialized"
933
+ logger.error(error_msg)
934
+ raise RuntimeError(error_msg)
935
+ storage_instance.db = db_instances[db_type]
936
+ logger.info(f"Injected {db_type} db to {storage_name}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
937
 
938
  # Auto scan documents if enabled
939
  if args.auto_scan_at_startup:
 
958
  yield
959
 
960
  finally:
961
+ # Clean up database connections
962
+ for db_type, db in db_instances.items():
963
+ if hasattr(db, "pool"):
964
+ await db.pool.close()
965
+ # Use more accurate database name display
966
+ db_names = {
967
+ "postgres": "PostgreSQL",
968
+ "oracle": "Oracle",
969
+ "tidb": "TiDB"
970
+ }
971
+ db_name = db_names.get(db_type, db_type)
972
+ logger.info(f"Closed {db_name} database connection pool")
973
 
974
  # Initialize FastAPI
975
  app = FastAPI(
 
1289
  case ".pdf":
1290
  if not pm.is_installed("pypdf2"):
1291
  pm.install("pypdf2")
1292
+ from PyPDF2 import PdfReader # type: ignore
1293
  from io import BytesIO
1294
 
1295
  pdf_file = BytesIO(file)