yangdx
commited on
Commit
·
a7c0d7b
1
Parent(s):
feb48a4
refactor: Implement dynamic database module imports
Browse files- Consolidate database instance management
- Improve database management and error handling
- Enhance error handling and logging
- lightrag/api/lightrag_server.py +76 -98
lightrag/api/lightrag_server.py
CHANGED
@@ -41,25 +41,28 @@ from .ollama_api import (
|
|
41 |
OllamaAPI,
|
42 |
)
|
43 |
from .ollama_api import ollama_server_infos
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
)
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
from ..kg.
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
63 |
|
64 |
# Load environment variables
|
65 |
try:
|
@@ -333,28 +336,28 @@ def parse_args() -> argparse.Namespace:
|
|
333 |
default=get_env_value(
|
334 |
"LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
|
335 |
),
|
336 |
-
help=f"KV
|
337 |
)
|
338 |
parser.add_argument(
|
339 |
"--doc-status-storage",
|
340 |
default=get_env_value(
|
341 |
"LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
|
342 |
),
|
343 |
-
help=f"
|
344 |
)
|
345 |
parser.add_argument(
|
346 |
"--graph-storage",
|
347 |
default=get_env_value(
|
348 |
"LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
|
349 |
),
|
350 |
-
help=f"
|
351 |
)
|
352 |
parser.add_argument(
|
353 |
"--vector-storage",
|
354 |
default=get_env_value(
|
355 |
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
|
356 |
),
|
357 |
-
help=f"
|
358 |
)
|
359 |
|
360 |
# Bindings configuration
|
@@ -890,72 +893,47 @@ def create_app(args):
|
|
890 |
async def lifespan(app: FastAPI):
|
891 |
"""Lifespan context manager for startup and shutdown events"""
|
892 |
# Initialize database connections
|
893 |
-
|
894 |
-
oracle_db = None
|
895 |
-
tidb_db = None
|
896 |
# Store background tasks
|
897 |
app.state.background_tasks = set()
|
898 |
|
899 |
try:
|
900 |
-
# Check
|
901 |
-
|
902 |
-
|
903 |
-
|
904 |
-
|
905 |
-
|
906 |
-
|
907 |
-
|
908 |
-
|
909 |
-
|
910 |
-
|
911 |
-
|
912 |
-
|
913 |
-
|
914 |
-
|
915 |
-
|
916 |
-
|
917 |
-
|
918 |
-
|
919 |
-
|
920 |
-
|
921 |
-
|
922 |
-
|
923 |
-
|
924 |
-
|
925 |
-
|
926 |
-
|
927 |
-
|
928 |
-
|
929 |
-
|
930 |
-
|
931 |
-
|
932 |
-
|
933 |
-
|
934 |
-
|
935 |
-
|
936 |
-
storage_instance,
|
937 |
-
(OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
|
938 |
-
):
|
939 |
-
storage_instance.db = oracle_db
|
940 |
-
logger.info(f"Injected oracle_db to {storage_name}")
|
941 |
-
|
942 |
-
# Check if TiDB is needed
|
943 |
-
if any(
|
944 |
-
isinstance(
|
945 |
-
storage_instance,
|
946 |
-
(TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage),
|
947 |
-
)
|
948 |
-
for _, storage_instance in storage_instances
|
949 |
-
):
|
950 |
-
tidb_db = TiDB(_get_tidb_config())
|
951 |
-
await tidb_db.check_tables()
|
952 |
-
for storage_name, storage_instance in storage_instances:
|
953 |
-
if isinstance(
|
954 |
-
storage_instance,
|
955 |
-
(TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage),
|
956 |
-
):
|
957 |
-
storage_instance.db = tidb_db
|
958 |
-
logger.info(f"Injected tidb_db to {storage_name}")
|
959 |
|
960 |
# Auto scan documents if enabled
|
961 |
if args.auto_scan_at_startup:
|
@@ -980,18 +958,18 @@ def create_app(args):
|
|
980 |
yield
|
981 |
|
982 |
finally:
|
983 |
-
#
|
984 |
-
|
985 |
-
|
986 |
-
|
987 |
-
|
988 |
-
|
989 |
-
|
990 |
-
|
991 |
-
|
992 |
-
|
993 |
-
|
994 |
-
|
995 |
|
996 |
# Initialize FastAPI
|
997 |
app = FastAPI(
|
@@ -1311,7 +1289,7 @@ def create_app(args):
|
|
1311 |
case ".pdf":
|
1312 |
if not pm.is_installed("pypdf2"):
|
1313 |
pm.install("pypdf2")
|
1314 |
-
from PyPDF2 import PdfReader
|
1315 |
from io import BytesIO
|
1316 |
|
1317 |
pdf_file = BytesIO(file)
|
|
|
41 |
OllamaAPI,
|
42 |
)
|
43 |
from .ollama_api import ollama_server_infos
|
44 |
+
def get_db_type_from_storage_class(class_name: str) -> str | None:
|
45 |
+
"""Determine database type based on storage class name"""
|
46 |
+
if class_name.startswith("PG"):
|
47 |
+
return "postgres"
|
48 |
+
elif class_name.startswith("Oracle"):
|
49 |
+
return "oracle"
|
50 |
+
elif class_name.startswith("TiDB"):
|
51 |
+
return "tidb"
|
52 |
+
return None
|
53 |
+
|
54 |
+
def import_db_module(db_type: str):
|
55 |
+
"""Dynamically import database module"""
|
56 |
+
if db_type == "postgres":
|
57 |
+
from ..kg.postgres_impl import PostgreSQLDB
|
58 |
+
return PostgreSQLDB
|
59 |
+
elif db_type == "oracle":
|
60 |
+
from ..kg.oracle_impl import OracleDB
|
61 |
+
return OracleDB
|
62 |
+
elif db_type == "tidb":
|
63 |
+
from ..kg.tidb_impl import TiDB
|
64 |
+
return TiDB
|
65 |
+
return None
|
66 |
|
67 |
# Load environment variables
|
68 |
try:
|
|
|
336 |
default=get_env_value(
|
337 |
"LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
|
338 |
),
|
339 |
+
help=f"KV storage implementation (default: {DefaultRAGStorageConfig.KV_STORAGE})",
|
340 |
)
|
341 |
parser.add_argument(
|
342 |
"--doc-status-storage",
|
343 |
default=get_env_value(
|
344 |
"LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
|
345 |
),
|
346 |
+
help=f"Document status storage implementation (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})",
|
347 |
)
|
348 |
parser.add_argument(
|
349 |
"--graph-storage",
|
350 |
default=get_env_value(
|
351 |
"LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
|
352 |
),
|
353 |
+
help=f"Graph storage implementation (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})",
|
354 |
)
|
355 |
parser.add_argument(
|
356 |
"--vector-storage",
|
357 |
default=get_env_value(
|
358 |
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
|
359 |
),
|
360 |
+
help=f"Vector storage implementation (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})",
|
361 |
)
|
362 |
|
363 |
# Bindings configuration
|
|
|
893 |
async def lifespan(app: FastAPI):
|
894 |
"""Lifespan context manager for startup and shutdown events"""
|
895 |
# Initialize database connections
|
896 |
+
db_instances = {}
|
|
|
|
|
897 |
# Store background tasks
|
898 |
app.state.background_tasks = set()
|
899 |
|
900 |
try:
|
901 |
+
# Check which database types are used
|
902 |
+
db_types = set()
|
903 |
+
for storage_name, storage_instance in storage_instances:
|
904 |
+
db_type = get_db_type_from_storage_class(storage_instance.__class__.__name__)
|
905 |
+
if db_type:
|
906 |
+
db_types.add(db_type)
|
907 |
+
|
908 |
+
# Import and initialize databases as needed
|
909 |
+
for db_type in db_types:
|
910 |
+
if db_type == "postgres":
|
911 |
+
DB = import_db_module("postgres")
|
912 |
+
db = DB(_get_postgres_config())
|
913 |
+
await db.initdb()
|
914 |
+
await db.check_tables()
|
915 |
+
db_instances["postgres"] = db
|
916 |
+
elif db_type == "oracle":
|
917 |
+
DB = import_db_module("oracle")
|
918 |
+
db = DB(_get_oracle_config())
|
919 |
+
await db.check_tables()
|
920 |
+
db_instances["oracle"] = db
|
921 |
+
elif db_type == "tidb":
|
922 |
+
DB = import_db_module("tidb")
|
923 |
+
db = DB(_get_tidb_config())
|
924 |
+
await db.check_tables()
|
925 |
+
db_instances["tidb"] = db
|
926 |
+
|
927 |
+
# Inject database instances into storage classes
|
928 |
+
for storage_name, storage_instance in storage_instances:
|
929 |
+
db_type = get_db_type_from_storage_class(storage_instance.__class__.__name__)
|
930 |
+
if db_type:
|
931 |
+
if db_type not in db_instances:
|
932 |
+
error_msg = f"Database type '{db_type}' is required by {storage_name} but not initialized"
|
933 |
+
logger.error(error_msg)
|
934 |
+
raise RuntimeError(error_msg)
|
935 |
+
storage_instance.db = db_instances[db_type]
|
936 |
+
logger.info(f"Injected {db_type} db to {storage_name}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
937 |
|
938 |
# Auto scan documents if enabled
|
939 |
if args.auto_scan_at_startup:
|
|
|
958 |
yield
|
959 |
|
960 |
finally:
|
961 |
+
# Clean up database connections
|
962 |
+
for db_type, db in db_instances.items():
|
963 |
+
if hasattr(db, "pool"):
|
964 |
+
await db.pool.close()
|
965 |
+
# Use more accurate database name display
|
966 |
+
db_names = {
|
967 |
+
"postgres": "PostgreSQL",
|
968 |
+
"oracle": "Oracle",
|
969 |
+
"tidb": "TiDB"
|
970 |
+
}
|
971 |
+
db_name = db_names.get(db_type, db_type)
|
972 |
+
logger.info(f"Closed {db_name} database connection pool")
|
973 |
|
974 |
# Initialize FastAPI
|
975 |
app = FastAPI(
|
|
|
1289 |
case ".pdf":
|
1290 |
if not pm.is_installed("pypdf2"):
|
1291 |
pm.install("pypdf2")
|
1292 |
+
from PyPDF2 import PdfReader # type: ignore
|
1293 |
from io import BytesIO
|
1294 |
|
1295 |
pdf_file = BytesIO(file)
|